updates
This commit is contained in:
parent
2f8b25acf7
commit
28d1adf631
7 changed files with 199 additions and 155 deletions
|
|
@ -80,22 +80,11 @@ class IsPresidentOf(BaseModel):
|
||||||
|
|
||||||
async def main(use_bulk: bool = False):
|
async def main(use_bulk: bool = False):
|
||||||
setup_logging()
|
setup_logging()
|
||||||
graph_driver = Neo4jDriver(
|
client = Graphiti(
|
||||||
neo4j_uri,
|
neo4j_uri,
|
||||||
neo4j_user,
|
neo4j_user,
|
||||||
neo4j_password,
|
neo4j_password,
|
||||||
aoss_host=aoss_host,
|
|
||||||
aoss_port=aoss_port,
|
|
||||||
region='us-west-2',
|
|
||||||
service='es',
|
|
||||||
)
|
)
|
||||||
# client = Graphiti(
|
|
||||||
# neo4j_uri,
|
|
||||||
# neo4j_user,
|
|
||||||
# neo4j_password,
|
|
||||||
# )
|
|
||||||
client = Graphiti(graph_driver=graph_driver)
|
|
||||||
await client.driver.create_aoss_indices()
|
|
||||||
await clear_data(client.driver)
|
await clear_data(client.driver)
|
||||||
await client.build_indices_and_constraints()
|
await client.build_indices_and_constraints()
|
||||||
messages = parse_podcast_messages()
|
messages = parse_podcast_messages()
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,11 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_SIZE = 10
|
DEFAULT_SIZE = 10
|
||||||
|
|
||||||
|
EPISODE_INDEX_NAME = 'episodes-test'
|
||||||
|
ENTTITY_INDEX_NAME = 'entities_test'
|
||||||
|
COMMUNITY_INDEX_NAME = 'communities-test'
|
||||||
|
ENTITY_EDGE_INDEX_NAME = 'entity_edges_test'
|
||||||
|
|
||||||
|
|
||||||
class GraphProvider(Enum):
|
class GraphProvider(Enum):
|
||||||
NEO4J = 'neo4j'
|
NEO4J = 'neo4j'
|
||||||
|
|
@ -48,20 +53,19 @@ class GraphProvider(Enum):
|
||||||
|
|
||||||
aoss_indices = [
|
aoss_indices = [
|
||||||
{
|
{
|
||||||
'index_name': 'entities_test',
|
'index_name': ENTTITY_INDEX_NAME,
|
||||||
'body': {
|
'body': {
|
||||||
|
'settings': {'index': {'knn': True}},
|
||||||
'mappings': {
|
'mappings': {
|
||||||
'properties': {
|
'properties': {
|
||||||
'uuid': {'type': 'keyword'},
|
'uuid': {'type': 'keyword'},
|
||||||
'name': {'type': 'text'},
|
'name': {'type': 'text'},
|
||||||
'summary': {'type': 'text'},
|
'summary': {'type': 'text'},
|
||||||
'group_id': {'type': 'text'},
|
'group_id': {'type': 'text'},
|
||||||
'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
|
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
'name_embedding': {
|
'name_embedding': {
|
||||||
'type': 'knn_vector',
|
'type': 'knn_vector',
|
||||||
'dims': EMBEDDING_DIM,
|
'dimension': EMBEDDING_DIM,
|
||||||
'index': True,
|
|
||||||
'similarity': 'cosine',
|
|
||||||
'method': {
|
'method': {
|
||||||
'engine': 'faiss',
|
'engine': 'faiss',
|
||||||
'space_type': 'cosinesimil',
|
'space_type': 'cosinesimil',
|
||||||
|
|
@ -70,11 +74,11 @@ aoss_indices = [
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'index_name': 'communities_test',
|
'index_name': COMMUNITY_INDEX_NAME,
|
||||||
'body': {
|
'body': {
|
||||||
'mappings': {
|
'mappings': {
|
||||||
'properties': {
|
'properties': {
|
||||||
|
|
@ -86,7 +90,7 @@ aoss_indices = [
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'index_name': 'episodes',
|
'index_name': EPISODE_INDEX_NAME,
|
||||||
'body': {
|
'body': {
|
||||||
'mappings': {
|
'mappings': {
|
||||||
'properties': {
|
'properties': {
|
||||||
|
|
@ -95,30 +99,29 @@ aoss_indices = [
|
||||||
'source': {'type': 'text'},
|
'source': {'type': 'text'},
|
||||||
'source_description': {'type': 'text'},
|
'source_description': {'type': 'text'},
|
||||||
'group_id': {'type': 'text'},
|
'group_id': {'type': 'text'},
|
||||||
'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
|
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
|
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'index_name': 'entity_edges_test',
|
'index_name': ENTITY_EDGE_INDEX_NAME,
|
||||||
'body': {
|
'body': {
|
||||||
|
'settings': {'index': {'knn': True}},
|
||||||
'mappings': {
|
'mappings': {
|
||||||
'properties': {
|
'properties': {
|
||||||
'uuid': {'type': 'keyword'},
|
'uuid': {'type': 'keyword'},
|
||||||
'name': {'type': 'text'},
|
'name': {'type': 'text'},
|
||||||
'fact': {'type': 'text'},
|
'fact': {'type': 'text'},
|
||||||
'group_id': {'type': 'text'},
|
'group_id': {'type': 'text'},
|
||||||
'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
|
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
|
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
'expired_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
|
'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
'invalid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
|
'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
||||||
'fact_embedding': {
|
'fact_embedding': {
|
||||||
'type': 'knn_vector',
|
'type': 'knn_vector',
|
||||||
'dims': EMBEDDING_DIM,
|
'dimension': EMBEDDING_DIM,
|
||||||
'index': True,
|
|
||||||
'similarity': 'cosine',
|
|
||||||
'method': {
|
'method': {
|
||||||
'engine': 'faiss',
|
'engine': 'faiss',
|
||||||
'space_type': 'cosinesimil',
|
'space_type': 'cosinesimil',
|
||||||
|
|
@ -127,7 +130,7 @@ aoss_indices = [
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
@ -219,19 +222,36 @@ class GraphDriver(ABC):
|
||||||
client.indices.put_alias(index=physical_index_name, name=alias_name)
|
client.indices.put_alias(index=physical_index_name, name=alias_name)
|
||||||
|
|
||||||
# Allow some time for index creation
|
# Allow some time for index creation
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
async def delete_aoss_indices(self):
|
async def delete_aoss_indices(self):
|
||||||
for index in aoss_indices:
|
client = self.aoss_client
|
||||||
index_name = index['index_name']
|
|
||||||
client = self.aoss_client
|
|
||||||
|
|
||||||
if not client:
|
if not client:
|
||||||
logger.warning('No OpenSearch client found')
|
logger.warning('No OpenSearch client found')
|
||||||
return
|
return
|
||||||
|
|
||||||
if client.indices.exists(index=index_name):
|
for entry in aoss_indices:
|
||||||
client.indices.delete(index=index_name)
|
alias_name = entry['index_name']
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Resolve alias → indices
|
||||||
|
alias_info = client.indices.get_alias(name=alias_name)
|
||||||
|
indices = list(alias_info.keys())
|
||||||
|
|
||||||
|
if not indices:
|
||||||
|
logger.info(f"No indices found for alias '{alias_name}'")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for index in indices:
|
||||||
|
if client.indices.exists(index=index):
|
||||||
|
client.indices.delete(index=index)
|
||||||
|
logger.info(f"Deleted index '{index}' (alias: {alias_name})")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting indices for alias '{alias_name}': {e}")
|
||||||
|
|
||||||
async def clear_aoss_indices(self):
|
async def clear_aoss_indices(self):
|
||||||
client = self.aoss_client
|
client = self.aoss_client
|
||||||
|
|
@ -277,7 +297,9 @@ class GraphDriver(ABC):
|
||||||
item[p] = d[p]
|
item[p] = d[p]
|
||||||
to_index.append(item)
|
to_index.append(item)
|
||||||
|
|
||||||
success, failed = helpers.bulk(client, to_index, stats_only=True)
|
success, failed = helpers.bulk(
|
||||||
|
client, to_index, stats_only=True, request_timeout=60
|
||||||
|
)
|
||||||
|
|
||||||
return success if failed == 0 else success
|
return success if failed == 0 else success
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,8 +56,9 @@ class Neo4jDriver(GraphDriver):
|
||||||
database: str = 'neo4j',
|
database: str = 'neo4j',
|
||||||
aoss_host: str | None = None,
|
aoss_host: str | None = None,
|
||||||
aoss_port: int | None = None,
|
aoss_port: int | None = None,
|
||||||
region: str | None = None,
|
aws_profile_name: str | None = None,
|
||||||
service: str | None = None,
|
aws_region: str | None = None,
|
||||||
|
aws_service: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.client = AsyncGraphDatabase.driver(
|
self.client = AsyncGraphDatabase.driver(
|
||||||
|
|
@ -69,9 +70,9 @@ class Neo4jDriver(GraphDriver):
|
||||||
self.aoss_client = None
|
self.aoss_client = None
|
||||||
if aoss_host and aoss_port and boto3 is not None:
|
if aoss_host and aoss_port and boto3 is not None:
|
||||||
try:
|
try:
|
||||||
region = region
|
region = aws_region
|
||||||
service = service
|
service = aws_service
|
||||||
credentials = boto3.Session(profile_name='zep-development').get_credentials()
|
credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
|
||||||
auth = AWSV4SignerAuth(credentials, region, service)
|
auth = AWSV4SignerAuth(credentials, region, service)
|
||||||
|
|
||||||
self.aoss_client = OpenSearch(
|
self.aoss_client = OpenSearch(
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from uuid import uuid4
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||||
from graphiti_core.helpers import parse_db_date
|
from graphiti_core.helpers import parse_db_date
|
||||||
|
|
@ -79,7 +79,7 @@ class Edge(BaseModel, ABC):
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
await driver.aoss_client.delete(
|
await driver.aoss_client.delete(
|
||||||
index='entity_edges', id=self.uuid, routing=self.group_id
|
index=ENTITY_EDGE_INDEX_NAME, id=self.uuid, routing=self.group_id
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Deleted Edge: {self.uuid}')
|
logger.debug(f'Deleted Edge: {self.uuid}')
|
||||||
|
|
@ -115,7 +115,7 @@ class Edge(BaseModel, ABC):
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
await driver.aoss_client.delete_by_query(
|
await driver.aoss_client.delete_by_query(
|
||||||
index='entity_edges',
|
index=ENTITY_EDGE_INDEX_NAME,
|
||||||
body={'query': {'terms': {'uuid': uuids}}},
|
body={'query': {'terms': {'uuid': uuids}}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -272,7 +272,7 @@ class EntityEdge(Edge):
|
||||||
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
||||||
'size': 1,
|
'size': 1,
|
||||||
},
|
},
|
||||||
index='entity_edges',
|
index=ENTITY_EDGE_INDEX_NAME,
|
||||||
routing=self.group_id,
|
routing=self.group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -325,7 +325,7 @@ class EntityEdge(Edge):
|
||||||
edge_data.update(self.attributes or {})
|
edge_data.update(self.attributes or {})
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
get_entity_edge_save_query(driver.provider),
|
get_entity_edge_save_query(driver.provider),
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,14 @@ from uuid import uuid4
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
from graphiti_core.driver.driver import (
|
||||||
|
COMMUNITY_INDEX_NAME,
|
||||||
|
ENTITY_EDGE_INDEX_NAME,
|
||||||
|
ENTTITY_INDEX_NAME,
|
||||||
|
EPISODE_INDEX_NAME,
|
||||||
|
GraphDriver,
|
||||||
|
GraphProvider,
|
||||||
|
)
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.errors import NodeNotFoundError
|
from graphiti_core.errors import NodeNotFoundError
|
||||||
from graphiti_core.helpers import parse_db_date
|
from graphiti_core.helpers import parse_db_date
|
||||||
|
|
@ -110,7 +117,7 @@ class Node(BaseModel, ABC):
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
# Delete the node from OpenSearch indices
|
# Delete the node from OpenSearch indices
|
||||||
for index in ('episodes', 'entities', 'communities'):
|
for index in (EPISODE_INDEX_NAME, ENTTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
|
||||||
await driver.aoss_client.delete(
|
await driver.aoss_client.delete(
|
||||||
index=index, id=self.uuid, routing=self.group_id
|
index=index, id=self.uuid, routing=self.group_id
|
||||||
)
|
)
|
||||||
|
|
@ -119,7 +126,9 @@ class Node(BaseModel, ABC):
|
||||||
if edge_uuids:
|
if edge_uuids:
|
||||||
actions = []
|
actions = []
|
||||||
for eid in edge_uuids:
|
for eid in edge_uuids:
|
||||||
actions.append({'delete': {'_index': 'entity_edges', '_id': eid}})
|
actions.append(
|
||||||
|
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
|
||||||
|
)
|
||||||
|
|
||||||
await driver.aoss_client.bulk(body=actions)
|
await driver.aoss_client.bulk(body=actions)
|
||||||
|
|
||||||
|
|
@ -187,25 +196,25 @@ class Node(BaseModel, ABC):
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
await driver.aoss_client.delete_by_query(
|
await driver.aoss_client.delete_by_query(
|
||||||
index='episodes',
|
index=EPISODE_INDEX_NAME,
|
||||||
body={'query': {'term': {'group_id': group_id}}},
|
body={'query': {'term': {'group_id': group_id}}},
|
||||||
routing=group_id,
|
routing=group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
await driver.aoss_client.delete_by_query(
|
await driver.aoss_client.delete_by_query(
|
||||||
index='entities',
|
index=ENTTITY_INDEX_NAME,
|
||||||
body={'query': {'term': {'group_id': group_id}}},
|
body={'query': {'term': {'group_id': group_id}}},
|
||||||
routing=group_id,
|
routing=group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
await driver.aoss_client.delete_by_query(
|
await driver.aoss_client.delete_by_query(
|
||||||
index='communities',
|
index=COMMUNITY_INDEX_NAME,
|
||||||
body={'query': {'term': {'group_id': group_id}}},
|
body={'query': {'term': {'group_id': group_id}}},
|
||||||
routing=group_id,
|
routing=group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
await driver.aoss_client.delete_by_query(
|
await driver.aoss_client.delete_by_query(
|
||||||
index='entity_edges',
|
index=ENTITY_EDGE_INDEX_NAME,
|
||||||
body={'query': {'term': {'group_id': group_id}}},
|
body={'query': {'term': {'group_id': group_id}}},
|
||||||
routing=group_id,
|
routing=group_id,
|
||||||
)
|
)
|
||||||
|
|
@ -319,7 +328,7 @@ class Node(BaseModel, ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
for index in ('episodes', 'entities', 'communities'):
|
for index in (EPISODE_INDEX_NAME, ENTTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
|
||||||
await driver.aoss_client.delete_by_query(
|
await driver.aoss_client.delete_by_query(
|
||||||
index=index,
|
index=index,
|
||||||
body={'query': {'terms': {'uuid': uuids}}},
|
body={'query': {'terms': {'uuid': uuids}}},
|
||||||
|
|
@ -327,7 +336,8 @@ class Node(BaseModel, ABC):
|
||||||
|
|
||||||
if edge_uuids:
|
if edge_uuids:
|
||||||
actions = [
|
actions = [
|
||||||
{'delete': {'_index': 'entity_edges', '_id': eid}} for eid in edge_uuids
|
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
|
||||||
|
for eid in edge_uuids
|
||||||
]
|
]
|
||||||
await driver.aoss_client.bulk(body=actions)
|
await driver.aoss_client.bulk(body=actions)
|
||||||
|
|
||||||
|
|
@ -509,7 +519,7 @@ class EntityNode(Node):
|
||||||
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
||||||
'size': 1,
|
'size': 1,
|
||||||
},
|
},
|
||||||
index='entities',
|
index=ENTTITY_INDEX_NAME,
|
||||||
routing=self.group_id,
|
routing=self.group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -557,7 +567,7 @@ class EntityNode(Node):
|
||||||
labels = ':'.join(self.labels + ['Entity'])
|
labels = ':'.join(self.labels + ['Entity'])
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
driver.save_to_aoss(ENTTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,13 @@ import numpy as np
|
||||||
from numpy._typing import NDArray
|
from numpy._typing import NDArray
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
from graphiti_core.driver.driver import (
|
||||||
|
ENTITY_EDGE_INDEX_NAME,
|
||||||
|
ENTTITY_INDEX_NAME,
|
||||||
|
EPISODE_INDEX_NAME,
|
||||||
|
GraphDriver,
|
||||||
|
GraphProvider,
|
||||||
|
)
|
||||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||||
from graphiti_core.graph_queries import (
|
from graphiti_core.graph_queries import (
|
||||||
get_nodes_query,
|
get_nodes_query,
|
||||||
|
|
@ -209,11 +215,11 @@ async def edge_fulltext_search(
|
||||||
# Match the edge ids and return the values
|
# Match the edge ids and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as id
|
UNWIND $ids as id
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
AND id(e)=id
|
AND id(e)=id
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
AND id(e)=id
|
AND id(e)=id
|
||||||
|
|
@ -249,16 +255,19 @@ async def edge_fulltext_search(
|
||||||
route = group_ids[0] if group_ids else None
|
route = group_ids[0] if group_ids else None
|
||||||
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
||||||
res = driver.aoss_client.search(
|
res = driver.aoss_client.search(
|
||||||
index='entity_edges',
|
index=ENTITY_EDGE_INDEX_NAME,
|
||||||
routing=route,
|
routing=route,
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
query={
|
body={
|
||||||
'bool': {
|
'query': {
|
||||||
'filter': filters,
|
'bool': {
|
||||||
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
|
'filter': filters,
|
||||||
|
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if res['hits']['total']['value'] > 0:
|
if res['hits']['total']['value'] > 0:
|
||||||
input_uuids = {}
|
input_uuids = {}
|
||||||
for r in res['hits']['hits']:
|
for r in res['hits']['hits']:
|
||||||
|
|
@ -344,8 +353,8 @@ async def edge_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
||||||
|
|
@ -407,16 +416,17 @@ async def edge_similarity_search(
|
||||||
route = group_ids[0] if group_ids else None
|
route = group_ids[0] if group_ids else None
|
||||||
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
||||||
res = driver.aoss_client.search(
|
res = driver.aoss_client.search(
|
||||||
index='entity_edges',
|
index=ENTITY_EDGE_INDEX_NAME,
|
||||||
routing=route,
|
routing=route,
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
knn={
|
size=limit,
|
||||||
'field': 'fact_embedding',
|
body={
|
||||||
'query_vector': search_vector,
|
'query': {
|
||||||
'k': limit,
|
'knn': {
|
||||||
'num_candidates': 1000,
|
'fact_embedding': {'vector': list(map(float, search_vector)), 'k': limit}
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
query={'bool': {'filter': filters}},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if res['hits']['total']['value'] > 0:
|
if res['hits']['total']['value'] > 0:
|
||||||
|
|
@ -428,6 +438,7 @@ async def edge_similarity_search(
|
||||||
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
||||||
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||||
return entity_edges
|
return entity_edges
|
||||||
|
return []
|
||||||
|
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
|
|
@ -622,11 +633,11 @@ async def node_fulltext_search(
|
||||||
# Match the node ids and return the values
|
# Match the node ids and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as i
|
UNWIND $ids as i
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE n.uuid=i.id
|
WHERE n.uuid=i.id
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ get_entity_node_return_query(driver.provider)
|
+ get_entity_node_return_query(driver.provider)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY i.score DESC
|
ORDER BY i.score DESC
|
||||||
|
|
@ -647,24 +658,26 @@ async def node_fulltext_search(
|
||||||
route = group_ids[0] if group_ids else None
|
route = group_ids[0] if group_ids else None
|
||||||
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
||||||
res = driver.aoss_client.search(
|
res = driver.aoss_client.search(
|
||||||
'entities',
|
index=ENTTITY_INDEX_NAME,
|
||||||
routing=route,
|
routing=route,
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
query={
|
size=limit,
|
||||||
'bool': {
|
body={
|
||||||
'filter': filters,
|
'query': {
|
||||||
'must': [
|
'bool': {
|
||||||
{
|
'filter': filters,
|
||||||
'multi_match': {
|
'must': [
|
||||||
'query': query,
|
{
|
||||||
'field': ['name', 'summary'],
|
'multi_match': {
|
||||||
'operator': 'or',
|
'query': query,
|
||||||
|
'fields': ['name', 'summary'], # ✅ fixed key
|
||||||
|
'operator': 'or',
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
],
|
||||||
],
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
limit=limit,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if res['hits']['total']['value'] > 0:
|
if res['hits']['total']['value'] > 0:
|
||||||
|
|
@ -734,8 +747,8 @@ async def node_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||||
|
|
@ -764,11 +777,11 @@ async def node_similarity_search(
|
||||||
# Match the node ids and return the values
|
# Match the node ids and return the values
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $ids as i
|
UNWIND $ids as i
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE id(n)=i.id
|
WHERE id(n)=i.id
|
||||||
RETURN
|
RETURN
|
||||||
"""
|
"""
|
||||||
+ get_entity_node_return_query(driver.provider)
|
+ get_entity_node_return_query(driver.provider)
|
||||||
+ """
|
+ """
|
||||||
ORDER BY i.score DESC
|
ORDER BY i.score DESC
|
||||||
|
|
@ -790,16 +803,17 @@ async def node_similarity_search(
|
||||||
route = group_ids[0] if group_ids else None
|
route = group_ids[0] if group_ids else None
|
||||||
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
||||||
res = driver.aoss_client.search(
|
res = driver.aoss_client.search(
|
||||||
index='entities',
|
index=ENTTITY_INDEX_NAME,
|
||||||
routing=route,
|
routing=route,
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
knn={
|
size=limit,
|
||||||
'field': 'fact_embedding',
|
body={
|
||||||
'query_vector': search_vector,
|
'query': {
|
||||||
'k': limit,
|
'knn': {
|
||||||
'num_candidates': 1000,
|
'name_embedding': {'vector': list(map(float, search_vector)), 'k': limit}
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
query={'bool': {'filter': filters}},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if res['hits']['total']['value'] > 0:
|
if res['hits']['total']['value'] > 0:
|
||||||
|
|
@ -811,11 +825,12 @@ async def node_similarity_search(
|
||||||
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
||||||
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||||
return entity_nodes
|
return entity_nodes
|
||||||
|
return []
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH n, """
|
WITH n, """
|
||||||
|
|
@ -989,7 +1004,7 @@ async def episode_fulltext_search(
|
||||||
elif driver.aoss_client:
|
elif driver.aoss_client:
|
||||||
route = group_ids[0] if group_ids else None
|
route = group_ids[0] if group_ids else None
|
||||||
res = driver.aoss_client.search(
|
res = driver.aoss_client.search(
|
||||||
'episodes',
|
EPISODE_INDEX_NAME,
|
||||||
routing=route,
|
routing=route,
|
||||||
_source=['uuid'],
|
_source=['uuid'],
|
||||||
query={
|
query={
|
||||||
|
|
@ -1147,8 +1162,8 @@ async def community_similarity_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community)
|
MATCH (n:Community)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||||
|
|
@ -1207,8 +1222,8 @@ async def community_similarity_search(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (c:Community)
|
MATCH (c:Community)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH c,
|
WITH c,
|
||||||
|
|
@ -1350,9 +1365,9 @@ async def get_relevant_nodes(
|
||||||
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MATCH (n:Entity {group_id: $group_id})
|
MATCH (n:Entity {group_id: $group_id})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH node, n, """
|
WITH node, n, """
|
||||||
|
|
@ -1397,9 +1412,9 @@ async def get_relevant_nodes(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MATCH (n:Entity {group_id: $group_id})
|
MATCH (n:Entity {group_id: $group_id})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH node, n, """
|
WITH node, n, """
|
||||||
|
|
@ -1488,9 +1503,9 @@ async def get_relevant_edges(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge
|
WITH e, edge
|
||||||
|
|
@ -1560,9 +1575,9 @@ async def get_relevant_edges(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge, n, m, """
|
WITH e, edge, n, m, """
|
||||||
|
|
@ -1599,9 +1614,9 @@ async def get_relevant_edges(
|
||||||
# First get edge candidates
|
# First get edge candidates
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
|
|
@ -1647,9 +1662,9 @@ async def get_relevant_edges(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge, """
|
WITH e, edge, """
|
||||||
|
|
@ -1722,10 +1737,10 @@ async def get_edge_invalidation_candidates(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge
|
WITH e, edge
|
||||||
|
|
@ -1795,10 +1810,10 @@ async def get_edge_invalidation_candidates(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH edge, e, n, m, """
|
WITH edge, e, n, m, """
|
||||||
|
|
@ -1834,10 +1849,10 @@ async def get_edge_invalidation_candidates(
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH edge, e, """
|
WITH edge, e, """
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,14 @@ import numpy as np
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Any
|
from typing_extensions import Any
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
from graphiti_core.driver.driver import (
|
||||||
|
ENTITY_EDGE_INDEX_NAME,
|
||||||
|
ENTTITY_INDEX_NAME,
|
||||||
|
EPISODE_INDEX_NAME,
|
||||||
|
GraphDriver,
|
||||||
|
GraphDriverSession,
|
||||||
|
GraphProvider,
|
||||||
|
)
|
||||||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
|
|
@ -203,9 +210,9 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
)
|
)
|
||||||
|
|
||||||
if driver.aoss_client:
|
if driver.aoss_client:
|
||||||
driver.save_to_aoss('episodes', episodes)
|
driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
||||||
driver.save_to_aoss('entities', nodes)
|
driver.save_to_aoss(ENTTITY_INDEX_NAME, nodes)
|
||||||
driver.save_to_aoss('entity_edges', edges)
|
driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
||||||
|
|
||||||
|
|
||||||
async def extract_nodes_and_edges_bulk(
|
async def extract_nodes_and_edges_bulk(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue