updates
This commit is contained in:
parent
0884cc00e5
commit
37715f6261
7 changed files with 268 additions and 66 deletions
|
|
@ -25,6 +25,7 @@ from pydantic import BaseModel, Field
|
||||||
from transcript_parser import parse_podcast_messages
|
from transcript_parser import parse_podcast_messages
|
||||||
|
|
||||||
from graphiti_core import Graphiti
|
from graphiti_core import Graphiti
|
||||||
|
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||||
from graphiti_core.nodes import EpisodeType
|
from graphiti_core.nodes import EpisodeType
|
||||||
from graphiti_core.utils.bulk_utils import RawEpisode
|
from graphiti_core.utils.bulk_utils import RawEpisode
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||||
|
|
@ -34,6 +35,8 @@ load_dotenv()
|
||||||
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
|
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
|
||||||
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
|
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
|
||||||
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
||||||
|
aoss_host = os.environ.get('AOSS_HOST') or None
|
||||||
|
aoss_port = os.environ.get('AOSS_PORT') or None
|
||||||
|
|
||||||
|
|
||||||
def setup_logging():
|
def setup_logging():
|
||||||
|
|
@ -77,11 +80,16 @@ class IsPresidentOf(BaseModel):
|
||||||
|
|
||||||
async def main(use_bulk: bool = False):
|
async def main(use_bulk: bool = False):
|
||||||
setup_logging()
|
setup_logging()
|
||||||
client = Graphiti(
|
graph_driver = Neo4jDriver(
|
||||||
neo4j_uri,
|
neo4j_uri, neo4j_user, neo4j_password, aoss_host=aoss_host, aoss_port=aoss_port
|
||||||
neo4j_user,
|
|
||||||
neo4j_password,
|
|
||||||
)
|
)
|
||||||
|
# client = Graphiti(
|
||||||
|
# neo4j_uri,
|
||||||
|
# neo4j_user,
|
||||||
|
# neo4j_password,
|
||||||
|
# )
|
||||||
|
client = Graphiti(graph_driver=graph_driver)
|
||||||
|
await client.driver.create_aoss_indices()
|
||||||
await clear_data(client.driver)
|
await clear_data(client.driver)
|
||||||
await client.build_indices_and_constraints()
|
await client.build_indices_and_constraints()
|
||||||
messages = parse_podcast_messages()
|
messages = parse_podcast_messages()
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ class GraphProvider(Enum):
|
||||||
|
|
||||||
aoss_indices = [
|
aoss_indices = [
|
||||||
{
|
{
|
||||||
'index_name': 'entities',
|
'index_name': 'entities_test',
|
||||||
'body': {
|
'body': {
|
||||||
'mappings': {
|
'mappings': {
|
||||||
'properties': {
|
'properties': {
|
||||||
|
|
@ -74,7 +74,7 @@ aoss_indices = [
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'index_name': 'communities',
|
'index_name': 'communities_test',
|
||||||
'body': {
|
'body': {
|
||||||
'mappings': {
|
'mappings': {
|
||||||
'properties': {
|
'properties': {
|
||||||
|
|
@ -102,7 +102,7 @@ aoss_indices = [
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'index_name': 'entity_edges',
|
'index_name': 'entity_edges_test',
|
||||||
'body': {
|
'body': {
|
||||||
'mappings': {
|
'mappings': {
|
||||||
'properties': {
|
'properties': {
|
||||||
|
|
@ -233,6 +233,31 @@ class GraphDriver(ABC):
|
||||||
if client.indices.exists(index=index_name):
|
if client.indices.exists(index=index_name):
|
||||||
client.indices.delete(index=index_name)
|
client.indices.delete(index=index_name)
|
||||||
|
|
||||||
|
async def clear_aoss_indices(self):
|
||||||
|
client = self.aoss_client
|
||||||
|
|
||||||
|
if not client:
|
||||||
|
logger.warning('No OpenSearch client found')
|
||||||
|
return
|
||||||
|
|
||||||
|
for index in aoss_indices:
|
||||||
|
index_name = index['index_name']
|
||||||
|
|
||||||
|
if client.indices.exists(index=index_name):
|
||||||
|
try:
|
||||||
|
# Delete all documents but keep the index
|
||||||
|
response = client.delete_by_query(
|
||||||
|
index=index_name,
|
||||||
|
body={'query': {'match_all': {}}},
|
||||||
|
refresh=True,
|
||||||
|
conflicts='proceed',
|
||||||
|
)
|
||||||
|
logger.info(f"Cleared index '{index_name}': {response}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error clearing index '{index_name}': {e}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Index '{index_name}' does not exist")
|
||||||
|
|
||||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||||
client = self.aoss_client
|
client = self.aoss_client
|
||||||
if not client or not helpers:
|
if not client or not helpers:
|
||||||
|
|
|
||||||
|
|
@ -63,8 +63,8 @@ class Neo4jDriver(GraphDriver):
|
||||||
try:
|
try:
|
||||||
session = boto3.Session()
|
session = boto3.Session()
|
||||||
self.aoss_client = OpenSearch( # type: ignore
|
self.aoss_client = OpenSearch( # type: ignore
|
||||||
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
hosts=[{'host': aoss_host, 'port': aoss_port, 'scheme': 'https'}],
|
||||||
http_auth=Urllib3AWSV4SignerAuth( # type: ignore
|
http_auth=Urllib3AWSV4SignerAuth(
|
||||||
session.get_credentials(), session.region_name, 'aoss'
|
session.get_credentials(), session.region_name, 'aoss'
|
||||||
),
|
),
|
||||||
use_ssl=True,
|
use_ssl=True,
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,11 @@ class Edge(BaseModel, ABC):
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if driver.aoss_client:
|
||||||
|
await driver.aoss_client.delete(
|
||||||
|
index='entity_edges', id=self.uuid, routing=self.group_id
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f'Deleted Edge: {self.uuid}')
|
logger.debug(f'Deleted Edge: {self.uuid}')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -108,6 +113,12 @@ class Edge(BaseModel, ABC):
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if driver.aoss_client:
|
||||||
|
await driver.aoss_client.delete_by_query(
|
||||||
|
index='entity_edges',
|
||||||
|
body={'query': {'terms': {'uuid': uuids}}},
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f'Deleted Edges: {uuids}')
|
logger.debug(f'Deleted Edges: {uuids}')
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
|
|
@ -351,6 +362,35 @@ class EntityEdge(Edge):
|
||||||
raise EdgeNotFoundError(uuid)
|
raise EdgeNotFoundError(uuid)
|
||||||
return edges[0]
|
return edges[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_between_nodes(
|
||||||
|
cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
|
||||||
|
):
|
||||||
|
match_query = """
|
||||||
|
MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
|
||||||
|
"""
|
||||||
|
if driver.provider == GraphProvider.KUZU:
|
||||||
|
match_query = """
|
||||||
|
MATCH (n:Entity {uuid: $source_node_uuid})
|
||||||
|
-[:RELATES_TO]->(e:RelatesToNode_)
|
||||||
|
-[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
|
||||||
|
"""
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
match_query
|
||||||
|
+ """
|
||||||
|
RETURN
|
||||||
|
"""
|
||||||
|
+ get_entity_edge_return_query(driver.provider),
|
||||||
|
source_node_uuid=source_node_uuid,
|
||||||
|
target_node_uuid=target_node_uuid,
|
||||||
|
routing_='r',
|
||||||
|
)
|
||||||
|
|
||||||
|
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
||||||
|
|
||||||
|
return edges
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
if len(uuids) == 0:
|
if len(uuids) == 0:
|
||||||
|
|
|
||||||
|
|
@ -94,13 +94,35 @@ class Node(BaseModel, ABC):
|
||||||
async def delete(self, driver: GraphDriver):
|
async def delete(self, driver: GraphDriver):
|
||||||
match driver.provider:
|
match driver.provider:
|
||||||
case GraphProvider.NEO4J:
|
case GraphProvider.NEO4J:
|
||||||
await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
MATCH (n:Entity|Episodic|Community {uuid: $uuid})-[r]-()
|
||||||
|
WITH collect(r.uuid) AS edge_uuids, n
|
||||||
DETACH DELETE n
|
DETACH DELETE n
|
||||||
|
RETURN edge_uuids
|
||||||
""",
|
""",
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
edge_uuids: list[str] = []
|
||||||
|
if result and result[0].get('edge_uuids'):
|
||||||
|
edge_uuids = result[0]['edge_uuids']
|
||||||
|
|
||||||
|
if driver.aoss_client:
|
||||||
|
# Delete the node from OpenSearch indices
|
||||||
|
for index in ('episodes', 'entities', 'communities'):
|
||||||
|
await driver.aoss_client.delete(
|
||||||
|
index=index, id=self.uuid, routing=self.group_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Bulk delete the detached edges
|
||||||
|
if edge_uuids:
|
||||||
|
actions = []
|
||||||
|
for eid in edge_uuids:
|
||||||
|
actions.append({'delete': {'_index': 'entity_edges', '_id': eid}})
|
||||||
|
|
||||||
|
await driver.aoss_client.bulk(body=actions)
|
||||||
|
|
||||||
case GraphProvider.KUZU:
|
case GraphProvider.KUZU:
|
||||||
for label in ['Episodic', 'Community']:
|
for label in ['Episodic', 'Community']:
|
||||||
await driver.execute_query(
|
await driver.execute_query(
|
||||||
|
|
@ -162,6 +184,32 @@ class Node(BaseModel, ABC):
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if driver.aoss_client:
|
||||||
|
await driver.aoss_client.delete_by_query(
|
||||||
|
index='episodes',
|
||||||
|
body={'query': {'term': {'group_id': group_id}}},
|
||||||
|
routing=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await driver.aoss_client.delete_by_query(
|
||||||
|
index='entities',
|
||||||
|
body={'query': {'term': {'group_id': group_id}}},
|
||||||
|
routing=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await driver.aoss_client.delete_by_query(
|
||||||
|
index='communities',
|
||||||
|
body={'query': {'term': {'group_id': group_id}}},
|
||||||
|
routing=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await driver.aoss_client.delete_by_query(
|
||||||
|
index='entity_edges',
|
||||||
|
body={'query': {'term': {'group_id': group_id}}},
|
||||||
|
routing=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
case GraphProvider.KUZU:
|
case GraphProvider.KUZU:
|
||||||
for label in ['Episodic', 'Community']:
|
for label in ['Episodic', 'Community']:
|
||||||
await driver.execute_query(
|
await driver.execute_query(
|
||||||
|
|
@ -240,6 +288,23 @@ class Node(BaseModel, ABC):
|
||||||
)
|
)
|
||||||
case _: # Neo4J, Neptune
|
case _: # Neo4J, Neptune
|
||||||
async with driver.session() as session:
|
async with driver.session() as session:
|
||||||
|
# Collect all edge UUIDs before deleting nodes
|
||||||
|
result = await session.run(
|
||||||
|
"""
|
||||||
|
MATCH (n:Entity|Episodic|Community)
|
||||||
|
WHERE n.uuid IN $uuids
|
||||||
|
MATCH (n)-[r]-()
|
||||||
|
RETURN collect(r.uuid) AS edgeUuids
|
||||||
|
""",
|
||||||
|
uuids=uuids,
|
||||||
|
)
|
||||||
|
|
||||||
|
record = await result.single()
|
||||||
|
edge_uuids: list[str] = (
|
||||||
|
record['edgeUuids'] if record and record['edgeUuids'] else []
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now delete the nodes in batches
|
||||||
await session.run(
|
await session.run(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity|Episodic|Community)
|
MATCH (n:Entity|Episodic|Community)
|
||||||
|
|
@ -253,6 +318,19 @@ class Node(BaseModel, ABC):
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if driver.aoss_client:
|
||||||
|
for index in ('episodes', 'entities', 'communities'):
|
||||||
|
await driver.aoss_client.delete_by_query(
|
||||||
|
index=index,
|
||||||
|
body={'query': {'terms': {'uuid': uuids}}},
|
||||||
|
)
|
||||||
|
|
||||||
|
if edge_uuids:
|
||||||
|
actions = [
|
||||||
|
{'delete': {'_index': 'entity_edges', '_id': eid}} for eid in edge_uuids
|
||||||
|
]
|
||||||
|
await driver.aoss_client.bulk(body=actions)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -209,11 +209,11 @@ async def edge_fulltext_search(
|
||||||
# Match the edge ids and return the values
|
# Match the edge ids and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as id
|
UNWIND $ids as id
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
AND id(e)=id
|
AND id(e)=id
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
AND id(e)=id
|
AND id(e)=id
|
||||||
|
|
@ -344,8 +344,8 @@ async def edge_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
||||||
|
|
@ -622,11 +622,11 @@ async def node_fulltext_search(
|
||||||
# Match the edge ides and return the values
|
# Match the edge ides and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as i
|
UNWIND $ids as i
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE n.uuid=i.id
|
WHERE n.uuid=i.id
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ get_entity_node_return_query(driver.provider)
|
+ get_entity_node_return_query(driver.provider)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY i.score DESC
|
ORDER BY i.score DESC
|
||||||
|
|
@ -734,8 +734,8 @@ async def node_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||||
|
|
@ -764,11 +764,11 @@ async def node_similarity_search(
|
||||||
# Match the edge ides and return the values
|
# Match the edge ides and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as i
|
UNWIND $ids as i
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE id(n)=i.id
|
WHERE id(n)=i.id
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ get_entity_node_return_query(driver.provider)
|
+ get_entity_node_return_query(driver.provider)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY i.score DESC
|
ORDER BY i.score DESC
|
||||||
|
|
@ -814,8 +814,8 @@ async def node_similarity_search(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH n, """
|
WITH n, """
|
||||||
|
|
@ -1147,8 +1147,8 @@ async def community_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community)
|
MATCH (n:Community)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||||
|
|
@ -1207,8 +1207,8 @@ async def community_similarity_search(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (c:Community)
|
MATCH (c:Community)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH c,
|
WITH c,
|
||||||
|
|
@ -1350,9 +1350,9 @@ async def get_relevant_nodes(
|
||||||
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MATCH (n:Entity {group_id: $group_id})
|
MATCH (n:Entity {group_id: $group_id})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH node, n, """
|
WITH node, n, """
|
||||||
|
|
@ -1397,9 +1397,9 @@ async def get_relevant_nodes(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MATCH (n:Entity {group_id: $group_id})
|
MATCH (n:Entity {group_id: $group_id})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH node, n, """
|
WITH node, n, """
|
||||||
|
|
@ -1488,9 +1488,9 @@ async def get_relevant_edges(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge
|
WITH e, edge
|
||||||
|
|
@ -1560,9 +1560,9 @@ async def get_relevant_edges(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge, n, m, """
|
WITH e, edge, n, m, """
|
||||||
|
|
@ -1595,12 +1595,61 @@ async def get_relevant_edges(
|
||||||
}) AS matches
|
}) AS matches
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
elif driver.aoss_client:
|
||||||
|
# First get edge candidates
|
||||||
|
query = (
|
||||||
|
"""
|
||||||
|
UNWIND $edges AS edge
|
||||||
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
|
"""
|
||||||
|
+ filter_query
|
||||||
|
+ """
|
||||||
|
RETURN
|
||||||
|
e.uuid AS search_edge_uuid,
|
||||||
|
collect({
|
||||||
|
uuid: e.uuid,
|
||||||
|
source_node_uuid: startNode(e).uuid,
|
||||||
|
target_node_uuid: endNode(e).uuid,
|
||||||
|
created_at: e.created_at,
|
||||||
|
name: e.name,
|
||||||
|
group_id: e.group_id,
|
||||||
|
fact: e.fact,
|
||||||
|
fact_embedding: e.fact_embedding,
|
||||||
|
episodes: e.episodes,
|
||||||
|
expired_at: e.expired_at,
|
||||||
|
valid_at: e.valid_at,
|
||||||
|
invalid_at: e.invalid_at,
|
||||||
|
attributes: properties(e)
|
||||||
|
}) AS matches
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
results, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
|
edges=[edge.model_dump() for edge in edges],
|
||||||
|
limit=limit,
|
||||||
|
min_score=min_score,
|
||||||
|
routing_='r',
|
||||||
|
**filter_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
||||||
|
result['search_edge_uuid']: [
|
||||||
|
get_entity_edge_from_record(record, driver.provider)
|
||||||
|
for record in result['matches']
|
||||||
|
]
|
||||||
|
for result in results
|
||||||
|
}
|
||||||
|
|
||||||
|
group_id = edges[0].group_id
|
||||||
|
# semaphore_gather(*[edge_similarity_search(driver, )])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge, """
|
WITH e, edge, """
|
||||||
|
|
@ -1673,10 +1722,10 @@ async def get_edge_invalidation_candidates(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge
|
WITH e, edge
|
||||||
|
|
@ -1746,10 +1795,10 @@ async def get_edge_invalidation_candidates(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH edge, e, n, m, """
|
WITH edge, e, n, m, """
|
||||||
|
|
@ -1785,10 +1834,10 @@ async def get_edge_invalidation_candidates(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH edge, e, """
|
WITH edge, e, """
|
||||||
|
|
|
||||||
|
|
@ -95,6 +95,8 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
||||||
|
|
||||||
async def delete_all(tx):
|
async def delete_all(tx):
|
||||||
await tx.run('MATCH (n) DETACH DELETE n')
|
await tx.run('MATCH (n) DETACH DELETE n')
|
||||||
|
if driver.aoss_client:
|
||||||
|
await driver.clear_aoss_indices()
|
||||||
|
|
||||||
async def delete_group_ids(tx):
|
async def delete_group_ids(tx):
|
||||||
labels = ['Entity', 'Episodic', 'Community']
|
labels = ['Entity', 'Episodic', 'Community']
|
||||||
|
|
@ -151,9 +153,9 @@ async def retrieve_episodes(
|
||||||
|
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic)
|
MATCH (e:Episodic)
|
||||||
WHERE e.valid_at <= $reference_time
|
WHERE e.valid_at <= $reference_time
|
||||||
"""
|
"""
|
||||||
+ query_filter
|
+ query_filter
|
||||||
+ """
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue