OpenSearch Integration for Neo4j (#896)
* move aoss to driver * add indexes * don't save vectors to neo4j with aoss * load embeddings from aoss * add group_id routing * add search filters and similarity search * neptune regression update * update neptune for regression purposes * update index creation with aliasing * regression tested * update version * edits * claude suggestions * cleanup * updates * add embedding dim env var * use cosine sim * updates * updates * remove unused imports * update
This commit is contained in:
parent
a3479758d5
commit
0884cc00e5
16 changed files with 634 additions and 160 deletions
|
|
@ -14,15 +14,30 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
import copy
import logging
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from datetime import datetime, timezone
from enum import Enum
from typing import Any

from graphiti_core.embedder.client import EMBEDDING_DIM
|
||||
|
||||
try:
|
||||
from opensearchpy import OpenSearch, helpers
|
||||
|
||||
_HAS_OPENSEARCH = True
|
||||
except ImportError:
|
||||
OpenSearch = None
|
||||
helpers = None
|
||||
_HAS_OPENSEARCH = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SIZE = 10
|
||||
|
||||
|
||||
class GraphProvider(Enum):
|
||||
NEO4J = 'neo4j'
|
||||
|
|
@ -31,6 +46,93 @@ class GraphProvider(Enum):
|
|||
NEPTUNE = 'neptune'
|
||||
|
||||
|
||||
# Shared field fragments for the AOSS index mappings below.
_AOSS_DATE_FIELD = {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}

# k-NN vector field backed by a faiss HNSW graph with cosine distance.
# NOTE(review): OpenSearch `knn_vector` fields take `dimension` (the
# Elasticsearch-style `dims`/`index`/`similarity` keys are not valid
# knn_vector parameters and would be rejected by the mapper).
_AOSS_VECTOR_FIELD = {
    'type': 'knn_vector',
    'dimension': EMBEDDING_DIM,
    'method': {
        'engine': 'faiss',
        'space_type': 'cosinesimil',
        'name': 'hnsw',
        'parameters': {'ef_construction': 128, 'm': 16},
    },
}

# `index.knn` must be enabled on indices containing knn_vector fields for
# approximate nearest-neighbor search to work.
_AOSS_KNN_SETTINGS = {'index': {'knn': True}}

# Index definitions for Amazon OpenSearch Serverless (AOSS). Each entry's
# `index_name` is used as the stable alias; the physical index gets a
# timestamp suffix (see GraphDriver.create_aoss_indices).
aoss_indices = [
    {
        'index_name': 'entities',
        'body': {
            'settings': _AOSS_KNN_SETTINGS,
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'summary': {'type': 'text'},
                    'group_id': {'type': 'text'},
                    'created_at': _AOSS_DATE_FIELD,
                    'name_embedding': _AOSS_VECTOR_FIELD,
                }
            },
        },
    },
    {
        'index_name': 'communities',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
    },
    {
        'index_name': 'episodes',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'content': {'type': 'text'},
                    'source': {'type': 'text'},
                    'source_description': {'type': 'text'},
                    'group_id': {'type': 'text'},
                    'created_at': _AOSS_DATE_FIELD,
                    'valid_at': _AOSS_DATE_FIELD,
                }
            }
        },
    },
    {
        'index_name': 'entity_edges',
        'body': {
            'settings': _AOSS_KNN_SETTINGS,
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'fact': {'type': 'text'},
                    'group_id': {'type': 'text'},
                    'created_at': _AOSS_DATE_FIELD,
                    'valid_at': _AOSS_DATE_FIELD,
                    'expired_at': _AOSS_DATE_FIELD,
                    'invalid_at': _AOSS_DATE_FIELD,
                    'fact_embedding': _AOSS_VECTOR_FIELD,
                }
            },
        },
    },
]
|
||||
|
||||
|
||||
class GraphDriverSession(ABC):
|
||||
provider: GraphProvider
|
||||
|
||||
|
|
@ -61,6 +163,7 @@ class GraphDriver(ABC):
|
|||
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
||||
)
|
||||
_database: str
|
||||
aoss_client: OpenSearch | None # type: ignore
|
||||
|
||||
@abstractmethod
|
||||
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
||||
|
|
@ -87,3 +190,70 @@ class GraphDriver(ABC):
|
|||
cloned._database = database
|
||||
|
||||
return cloned
|
||||
|
||||
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
|
||||
# No matter what happens above, always return True
|
||||
return self.delete_aoss_indices()
|
||||
|
||||
async def create_aoss_indices(self):
|
||||
client = self.aoss_client
|
||||
if not client:
|
||||
logger.warning('No OpenSearch client found')
|
||||
return
|
||||
|
||||
for index in aoss_indices:
|
||||
alias_name = index['index_name']
|
||||
|
||||
# If alias already exists, skip (idempotent behavior)
|
||||
if client.indices.exists_alias(name=alias_name):
|
||||
continue
|
||||
|
||||
# Build a physical index name with timestamp
|
||||
ts_suffix = datetime.utcnow().strftime('%Y%m%d%H%M%S')
|
||||
physical_index_name = f'{alias_name}_{ts_suffix}'
|
||||
|
||||
# Create the index
|
||||
client.indices.create(index=physical_index_name, body=index['body'])
|
||||
|
||||
# Point alias to it
|
||||
client.indices.put_alias(index=physical_index_name, name=alias_name)
|
||||
|
||||
# Allow some time for index creation
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def delete_aoss_indices(self):
|
||||
for index in aoss_indices:
|
||||
index_name = index['index_name']
|
||||
client = self.aoss_client
|
||||
|
||||
if not client:
|
||||
logger.warning('No OpenSearch client found')
|
||||
return
|
||||
|
||||
if client.indices.exists(index=index_name):
|
||||
client.indices.delete(index=index_name)
|
||||
|
||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||
client = self.aoss_client
|
||||
if not client or not helpers:
|
||||
logger.warning('No OpenSearch client found')
|
||||
return 0
|
||||
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
to_index = []
|
||||
for d in data:
|
||||
item = {
|
||||
'_index': name,
|
||||
'_routing': d.get('group_id'), # shard routing
|
||||
}
|
||||
for p in index['body']['mappings']['properties']:
|
||||
if p in d: # protect against missing fields
|
||||
item[p] = d[p]
|
||||
to_index.append(item)
|
||||
|
||||
success, failed = helpers.bulk(client, to_index, stats_only=True)
|
||||
|
||||
return success if failed == 0 else success
|
||||
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ class FalkorDriverSession(GraphDriverSession):
|
|||
|
||||
class FalkorDriver(GraphDriver):
|
||||
provider = GraphProvider.FALKORDB
|
||||
aoss_client: None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -92,6 +92,7 @@ SCHEMA_QUERIES = """
|
|||
|
||||
class KuzuDriver(GraphDriver):
|
||||
provider: GraphProvider = GraphProvider.KUZU
|
||||
aoss_client: None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -22,14 +22,35 @@ from neo4j import AsyncGraphDatabase, EagerResult
|
|||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
|
||||
|
||||
_HAS_OPENSEARCH = True
|
||||
except ImportError:
|
||||
boto3 = None
|
||||
OpenSearch = None
|
||||
Urllib3AWSV4SignerAuth = None
|
||||
Urllib3HttpConnection = None
|
||||
_HAS_OPENSEARCH = False
|
||||
|
||||
|
||||
class Neo4jDriver(GraphDriver):
|
||||
provider = GraphProvider.NEO4J
|
||||
|
||||
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
user: str | None,
|
||||
password: str | None,
|
||||
database: str = 'neo4j',
|
||||
aoss_host: str | None = None,
|
||||
aoss_port: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.client = AsyncGraphDatabase.driver(
|
||||
uri=uri,
|
||||
|
|
@ -37,6 +58,24 @@ class Neo4jDriver(GraphDriver):
|
|||
)
|
||||
self._database = database
|
||||
|
||||
self.aoss_client = None
|
||||
if aoss_host and aoss_port and boto3 is not None:
|
||||
try:
|
||||
session = boto3.Session()
|
||||
self.aoss_client = OpenSearch( # type: ignore
|
||||
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
||||
http_auth=Urllib3AWSV4SignerAuth( # type: ignore
|
||||
session.get_credentials(), session.region_name, 'aoss'
|
||||
),
|
||||
use_ssl=True,
|
||||
verify_certs=True,
|
||||
connection_class=Urllib3HttpConnection,
|
||||
pool_maxsize=20,
|
||||
) # type: ignore
|
||||
except Exception as e:
|
||||
logger.warning(f'Failed to initialize OpenSearch client: {e}')
|
||||
self.aoss_client = None
|
||||
|
||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
||||
# Check if database_ is provided in kwargs.
|
||||
# If not populated, set the value to retain backwards compatibility
|
||||
|
|
@ -60,7 +99,14 @@ class Neo4jDriver(GraphDriver):
|
|||
async def close(self) -> None:
|
||||
return await self.client.close()
|
||||
|
||||
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
|
||||
def delete_all_indexes(self) -> Coroutine:
|
||||
if self.aoss_client:
|
||||
return semaphore_gather(
|
||||
self.client.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
),
|
||||
self.delete_aoss_indices(),
|
||||
)
|
||||
return self.client.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,16 +22,21 @@ from typing import Any
|
|||
|
||||
import boto3
|
||||
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
from graphiti_core.driver.driver import (
|
||||
DEFAULT_SIZE,
|
||||
GraphDriver,
|
||||
GraphDriverSession,
|
||||
GraphProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_SIZE = 10
|
||||
|
||||
aoss_indices = [
|
||||
neptune_aoss_indices = [
|
||||
{
|
||||
'index_name': 'node_name_and_summary',
|
||||
'alias_name': 'entities',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -49,6 +54,7 @@ aoss_indices = [
|
|||
},
|
||||
{
|
||||
'index_name': 'community_name',
|
||||
'alias_name': 'communities',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -65,6 +71,7 @@ aoss_indices = [
|
|||
},
|
||||
{
|
||||
'index_name': 'episode_content',
|
||||
'alias_name': 'episodes',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -88,6 +95,7 @@ aoss_indices = [
|
|||
},
|
||||
{
|
||||
'index_name': 'edge_name_and_fact',
|
||||
'alias_name': 'facts',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -220,54 +228,27 @@ class NeptuneDriver(GraphDriver):
|
|||
async def _delete_all_data(self) -> Any:
|
||||
return await self.execute_query('MATCH (n) DETACH DELETE n')
|
||||
|
||||
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||
return self.delete_all_indexes_impl()
|
||||
|
||||
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
|
||||
# No matter what happens above, always return True
|
||||
return self.delete_aoss_indices()
|
||||
|
||||
async def create_aoss_indices(self):
|
||||
for index in aoss_indices:
|
||||
for index in neptune_aoss_indices:
|
||||
index_name = index['index_name']
|
||||
client = self.aoss_client
|
||||
if not client:
|
||||
raise ValueError(
|
||||
'You must provide an AOSS endpoint to create an OpenSearch driver.'
|
||||
)
|
||||
if not client.indices.exists(index=index_name):
|
||||
client.indices.create(index=index_name, body=index['body'])
|
||||
|
||||
alias_name = index.get('alias_name', index_name)
|
||||
|
||||
if not client.indices.exists_alias(name=alias_name, index=index_name):
|
||||
client.indices.put_alias(index=index_name, name=alias_name)
|
||||
|
||||
# Sleep for 1 minute to let the index creation complete
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def delete_aoss_indices(self):
|
||||
for index in aoss_indices:
|
||||
index_name = index['index_name']
|
||||
client = self.aoss_client
|
||||
if client.indices.exists(index=index_name):
|
||||
client.indices.delete(index=index_name)
|
||||
|
||||
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
index['query']['query']['multi_match']['query'] = query_text
|
||||
query = {'size': limit, 'query': index['query']}
|
||||
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
|
||||
return resp
|
||||
return {}
|
||||
|
||||
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
||||
for index in aoss_indices:
|
||||
if name.lower() == index['index_name']:
|
||||
to_index = []
|
||||
for d in data:
|
||||
item = {'_index': name}
|
||||
for p in index['body']['mappings']['properties']:
|
||||
item[p] = d[p]
|
||||
to_index.append(item)
|
||||
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
||||
if failed > 0:
|
||||
return success
|
||||
else:
|
||||
return 0
|
||||
|
||||
return 0
|
||||
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||
return self.delete_all_indexes_impl()
|
||||
|
||||
|
||||
class NeptuneDriverSession(GraphDriverSession):
|
||||
|
|
|
|||
|
|
@ -255,6 +255,21 @@ class EntityEdge(Edge):
|
|||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
||||
"""
|
||||
elif driver.aoss_client:
|
||||
resp = driver.aoss_client.search(
|
||||
body={
|
||||
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
||||
'size': 1,
|
||||
},
|
||||
index='entity_edges',
|
||||
routing=self.group_id,
|
||||
)
|
||||
|
||||
if resp['hits']['hits']:
|
||||
self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
|
||||
return
|
||||
else:
|
||||
raise EdgeNotFoundError(self.uuid)
|
||||
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
query = """
|
||||
|
|
@ -292,14 +307,14 @@ class EntityEdge(Edge):
|
|||
if driver.provider == GraphProvider.KUZU:
|
||||
edge_data['attributes'] = json.dumps(self.attributes)
|
||||
result = await driver.execute_query(
|
||||
get_entity_edge_save_query(driver.provider),
|
||||
get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
|
||||
**edge_data,
|
||||
)
|
||||
else:
|
||||
edge_data.update(self.attributes or {})
|
||||
|
||||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_entity_edge_save_query(driver.provider),
|
||||
|
|
|
|||
|
|
@ -14,12 +14,13 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Embedding dimensionality, overridable via the EMBEDDING_DIM environment
# variable (default 1024). Must match the embedding model in use and the
# vector-index mappings built from it. The default is a string so both the
# env-var and fallback paths go through the same int() conversion.
EMBEDDING_DIM = int(os.getenv('EMBEDDING_DIM', '1024'))
|
||||
|
||||
|
||||
class EmbedderConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ EPISODIC_EDGE_RETURN = """
|
|||
"""
|
||||
|
||||
|
||||
def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
||||
def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
|
||||
match provider:
|
||||
case GraphProvider.FALKORDB:
|
||||
return """
|
||||
|
|
@ -99,17 +99,28 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
|||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||
SET e = $edge_data
|
||||
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
|
||||
save_embedding_query = (
|
||||
"""WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
|
||||
if not has_aoss
|
||||
else ''
|
||||
)
|
||||
return (
|
||||
(
|
||||
"""
|
||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||
SET e = $edge_data
|
||||
"""
|
||||
+ save_embedding_query
|
||||
)
|
||||
+ """
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
||||
def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
|
||||
match provider:
|
||||
case GraphProvider.FALKORDB:
|
||||
return """
|
||||
|
|
@ -152,15 +163,24 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
|||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
case _:
|
||||
return """
|
||||
UNWIND $entity_edges AS edge
|
||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||
SET e = edge
|
||||
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
|
||||
save_embedding_query = (
|
||||
'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
|
||||
if not has_aoss
|
||||
else ''
|
||||
)
|
||||
return (
|
||||
"""
|
||||
UNWIND $entity_edges AS edge
|
||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||
SET e = edge
|
||||
"""
|
||||
+ save_embedding_query
|
||||
+ """
|
||||
RETURN edge.uuid AS uuid
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def get_entity_edge_return_query(provider: GraphProvider) -> str:
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ EPISODIC_NODE_RETURN_NEPTUNE = """
|
|||
"""
|
||||
|
||||
|
||||
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
||||
def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
|
||||
match provider:
|
||||
case GraphProvider.FALKORDB:
|
||||
return f"""
|
||||
|
|
@ -161,16 +161,27 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
|||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _:
|
||||
return f"""
|
||||
save_embedding_query = (
|
||||
'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
|
||||
if not has_aoss
|
||||
else ''
|
||||
)
|
||||
return (
|
||||
f"""
|
||||
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||
SET n:{labels}
|
||||
SET n = $entity_data
|
||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
|
||||
"""
|
||||
+ save_embedding_query
|
||||
+ """
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
|
||||
def get_entity_node_save_bulk_query(
|
||||
provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
|
||||
) -> str | Any:
|
||||
match provider:
|
||||
case GraphProvider.FALKORDB:
|
||||
queries = []
|
||||
|
|
@ -222,14 +233,23 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
|
|||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j
|
||||
return """
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n:Entity {uuid: node.uuid})
|
||||
SET n:$(node.labels)
|
||||
SET n = node
|
||||
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
|
||||
save_embedding_query = (
|
||||
'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
|
||||
if not has_aoss
|
||||
else ''
|
||||
)
|
||||
return (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n:Entity {uuid: node.uuid})
|
||||
SET n:$(node.labels)
|
||||
SET n = node
|
||||
"""
|
||||
+ save_embedding_query
|
||||
+ """
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def get_entity_node_return_query(provider: GraphProvider) -> str:
|
||||
|
|
|
|||
|
|
@ -273,20 +273,6 @@ class EpisodicNode(Node):
|
|||
)
|
||||
|
||||
async def save(self, driver: GraphDriver):
|
||||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||
'episode_content',
|
||||
[
|
||||
{
|
||||
'uuid': self.uuid,
|
||||
'group_id': self.group_id,
|
||||
'source': self.source.value,
|
||||
'content': self.content,
|
||||
'source_description': self.source_description,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
episode_args = {
|
||||
'uuid': self.uuid,
|
||||
'name': self.name,
|
||||
|
|
@ -299,6 +285,12 @@ class EpisodicNode(Node):
|
|||
'source': self.source.value,
|
||||
}
|
||||
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||
'episodes',
|
||||
[episode_args],
|
||||
)
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_episode_node_save_query(driver.provider), **episode_args
|
||||
)
|
||||
|
|
@ -433,6 +425,22 @@ class EntityNode(Node):
|
|||
MATCH (n:Entity {uuid: $uuid})
|
||||
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
||||
"""
|
||||
elif driver.aoss_client:
|
||||
resp = driver.aoss_client.search(
|
||||
body={
|
||||
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
||||
'size': 1,
|
||||
},
|
||||
index='entities',
|
||||
routing=self.group_id,
|
||||
)
|
||||
|
||||
if resp['hits']['hits']:
|
||||
self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
|
||||
return
|
||||
else:
|
||||
raise NodeNotFoundError(self.uuid)
|
||||
|
||||
else:
|
||||
query: LiteralString = """
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
|
|
@ -470,11 +478,11 @@ class EntityNode(Node):
|
|||
entity_data.update(self.attributes or {})
|
||||
labels = ':'.join(self.labels + ['Entity'])
|
||||
|
||||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_entity_node_save_query(driver.provider, labels),
|
||||
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
||||
entity_data=entity_data,
|
||||
)
|
||||
|
||||
|
|
@ -570,7 +578,7 @@ class CommunityNode(Node):
|
|||
async def save(self, driver: GraphDriver):
|
||||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||
'community_name',
|
||||
'communities',
|
||||
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
||||
)
|
||||
result = await driver.execute_query(
|
||||
|
|
|
|||
|
|
@ -54,6 +54,16 @@ class SearchFilters(BaseModel):
|
|||
expired_at: list[list[DateFilter]] | None = Field(default=None)
|
||||
|
||||
|
||||
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
    """Translate a comparison operator into its OpenSearch range-query
    keyword (gt/lt/gte/lte); any other operator falls through to its raw
    string value."""
    translation = {
        ComparisonOperator.greater_than: 'gt',
        ComparisonOperator.less_than: 'lt',
        ComparisonOperator.greater_than_equal: 'gte',
        ComparisonOperator.less_than_equal: 'lte',
    }
    if op in translation:
        return translation[op]
    return op.value
|
||||
|
||||
|
||||
def node_search_filter_query_constructor(
|
||||
filters: SearchFilters,
|
||||
provider: GraphProvider,
|
||||
|
|
@ -234,3 +244,38 @@ def edge_search_filter_query_constructor(
|
|||
filter_queries.append(expired_at_filter)
|
||||
|
||||
return filter_queries, filter_params
|
||||
|
||||
|
||||
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
    """Build the OpenSearch boolean `filter` clauses for a node search:
    always restricts by group_id, and by node labels when supplied."""
    clauses: list[dict] = [{'terms': {'group_id': group_ids}}]

    if search_filters.node_labels:
        clauses.append({'terms': {'node_labels': search_filters.node_labels}})

    return clauses
|
||||
|
||||
|
||||
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
    """Build the OpenSearch boolean `filter` clauses for an edge search.

    Always restricts by group_id; optionally by edge type, and by each date
    field. A date field's filter is an OR of AND-groups, mirroring the
    Cypher filter construction.
    """
    clauses: list[dict] = [{'terms': {'group_id': group_ids}}]

    if search_filters.edge_types:
        clauses.append({'terms': {'edge_types': search_filters.edge_types}})

    for date_field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
        or_groups = getattr(search_filters, date_field)
        if not or_groups:
            continue

        # Outer list is OR-ed (should), each inner group is AND-ed (filter).
        alternatives = []
        for and_group in or_groups:
            conjuncts = [
                {
                    'range': {
                        date_field: {
                            cypher_to_opensearch_operator(df.comparison_operator): df.date
                        }
                    }
                }
                for df in and_group  # df is a DateFilter
            ]
            alternatives.append({'bool': {'filter': conjuncts}})

        clauses.append({'bool': {'should': alternatives, 'minimum_should_match': 1}})

    return clauses
|
||||
|
|
|
|||
|
|
@ -51,6 +51,8 @@ from graphiti_core.nodes import (
|
|||
)
|
||||
from graphiti_core.search.search_filters import (
|
||||
SearchFilters,
|
||||
build_aoss_edge_filters,
|
||||
build_aoss_node_filters,
|
||||
edge_search_filter_query_constructor,
|
||||
node_search_filter_query_constructor,
|
||||
)
|
||||
|
|
@ -200,7 +202,6 @@ async def edge_fulltext_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
|
||||
if res['hits']['total']['value'] > 0:
|
||||
# Calculate Cosine similarity then return the edge ids
|
||||
input_ids = []
|
||||
for r in res['hits']['hits']:
|
||||
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
||||
|
|
@ -208,11 +209,11 @@ async def edge_fulltext_search(
|
|||
# Match the edge ids and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
AND id(e)=id
|
||||
|
|
@ -244,6 +245,31 @@ async def edge_fulltext_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.aoss_client:
|
||||
route = group_ids[0] if group_ids else None
|
||||
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
||||
res = driver.aoss_client.search(
|
||||
index='entity_edges',
|
||||
routing=route,
|
||||
_source=['uuid'],
|
||||
query={
|
||||
'bool': {
|
||||
'filter': filters,
|
||||
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
|
||||
}
|
||||
},
|
||||
)
|
||||
if res['hits']['total']['value'] > 0:
|
||||
input_uuids = {}
|
||||
for r in res['hits']['hits']:
|
||||
input_uuids[r['_source']['uuid']] = r['_score']
|
||||
|
||||
# Get edges
|
||||
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return entity_edges
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
query = (
|
||||
get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
|
||||
|
|
@ -318,8 +344,8 @@ async def edge_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
||||
|
|
@ -377,6 +403,32 @@ async def edge_similarity_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.aoss_client:
|
||||
route = group_ids[0] if group_ids else None
|
||||
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
||||
res = driver.aoss_client.search(
|
||||
index='entity_edges',
|
||||
routing=route,
|
||||
_source=['uuid'],
|
||||
knn={
|
||||
'field': 'fact_embedding',
|
||||
'query_vector': search_vector,
|
||||
'k': limit,
|
||||
'num_candidates': 1000,
|
||||
},
|
||||
query={'bool': {'filter': filters}},
|
||||
)
|
||||
|
||||
if res['hits']['total']['value'] > 0:
|
||||
input_uuids = {}
|
||||
for r in res['hits']['hits']:
|
||||
input_uuids[r['_source']['uuid']] = r['_score']
|
||||
|
||||
# Get edges
|
||||
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return entity_edges
|
||||
|
||||
else:
|
||||
query = (
|
||||
match_query
|
||||
|
|
@ -563,7 +615,6 @@ async def node_fulltext_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
||||
if res['hits']['total']['value'] > 0:
|
||||
# Calculate Cosine similarity then return the edge ids
|
||||
input_ids = []
|
||||
for r in res['hits']['hits']:
|
||||
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
||||
|
|
@ -571,11 +622,11 @@ async def node_fulltext_search(
|
|||
# Match the edge ides and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -592,6 +643,41 @@ async def node_fulltext_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.aoss_client:
|
||||
route = group_ids[0] if group_ids else None
|
||||
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
||||
res = driver.aoss_client.search(
|
||||
'entities',
|
||||
routing=route,
|
||||
_source=['uuid'],
|
||||
query={
|
||||
'bool': {
|
||||
'filter': filters,
|
||||
'must': [
|
||||
{
|
||||
'multi_match': {
|
||||
'query': query,
|
||||
'field': ['name', 'summary'],
|
||||
'operator': 'or',
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if res['hits']['total']['value'] > 0:
|
||||
input_uuids = {}
|
||||
for r in res['hits']['hits']:
|
||||
input_uuids[r['_source']['uuid']] = r['_score']
|
||||
|
||||
# Get nodes
|
||||
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return entities
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
query = (
|
||||
get_nodes_query(
|
||||
|
|
@ -648,8 +734,8 @@ async def node_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||
|
|
@ -678,11 +764,11 @@ async def node_similarity_search(
|
|||
# Match the edge ides and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -700,11 +786,36 @@ async def node_similarity_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.aoss_client:
|
||||
route = group_ids[0] if group_ids else None
|
||||
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
||||
res = driver.aoss_client.search(
|
||||
index='entities',
|
||||
routing=route,
|
||||
_source=['uuid'],
|
||||
knn={
|
||||
'field': 'fact_embedding',
|
||||
'query_vector': search_vector,
|
||||
'k': limit,
|
||||
'num_candidates': 1000,
|
||||
},
|
||||
query={'bool': {'filter': filters}},
|
||||
)
|
||||
|
||||
if res['hits']['total']['value'] > 0:
|
||||
input_uuids = {}
|
||||
for r in res['hits']['hits']:
|
||||
input_uuids[r['_source']['uuid']] = r['_score']
|
||||
|
||||
# Get edges
|
||||
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return entity_nodes
|
||||
else:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH n, """
|
||||
|
|
@ -843,7 +954,6 @@ async def episode_fulltext_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
||||
if res['hits']['total']['value'] > 0:
|
||||
# Calculate Cosine similarity then return the edge ids
|
||||
input_ids = []
|
||||
for r in res['hits']['hits']:
|
||||
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
||||
|
|
@ -852,7 +962,7 @@ async def episode_fulltext_search(
|
|||
query = """
|
||||
UNWIND $ids as i
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.uuid=i.id
|
||||
WHERE e.uuid=i.uuid
|
||||
RETURN
|
||||
e.content AS content,
|
||||
e.created_at AS created_at,
|
||||
|
|
@ -876,6 +986,40 @@ async def episode_fulltext_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.aoss_client:
|
||||
route = group_ids[0] if group_ids else None
|
||||
res = driver.aoss_client.search(
|
||||
'episodes',
|
||||
routing=route,
|
||||
_source=['uuid'],
|
||||
query={
|
||||
'bool': {
|
||||
'filter': {'terms': group_ids},
|
||||
'must': [
|
||||
{
|
||||
'multi_match': {
|
||||
'query': query,
|
||||
'field': ['name', 'content'],
|
||||
'operator': 'or',
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if res['hits']['total']['value'] > 0:
|
||||
input_uuids = {}
|
||||
for r in res['hits']['hits']:
|
||||
input_uuids[r['_source']['uuid']] = r['_score']
|
||||
|
||||
# Get nodes
|
||||
episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return episodes
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
query = (
|
||||
get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
|
||||
|
|
@ -1003,8 +1147,8 @@ async def community_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||
|
|
@ -1063,8 +1207,8 @@ async def community_similarity_search(
|
|||
|
||||
query = (
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
WITH c,
|
||||
|
|
@ -1206,9 +1350,9 @@ async def get_relevant_nodes(
|
|||
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
||||
query = (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH node, n, """
|
||||
|
|
@ -1253,9 +1397,9 @@ async def get_relevant_nodes(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH node, n, """
|
||||
|
|
@ -1344,9 +1488,9 @@ async def get_relevant_edges(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge
|
||||
|
|
@ -1416,9 +1560,9 @@ async def get_relevant_edges(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge, n, m, """
|
||||
|
|
@ -1454,9 +1598,9 @@ async def get_relevant_edges(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge, """
|
||||
|
|
@ -1529,10 +1673,10 @@ async def get_edge_invalidation_candidates(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge
|
||||
|
|
@ -1602,10 +1746,10 @@ async def get_edge_invalidation_candidates(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH edge, e, n, m, """
|
||||
|
|
@ -1641,10 +1785,10 @@ async def get_edge_invalidation_candidates(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH edge, e, """
|
||||
|
|
|
|||
|
|
@ -187,12 +187,25 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
await tx.run(episodic_edge_query, **edge.model_dump())
|
||||
else:
|
||||
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
||||
await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
|
||||
await tx.run(
|
||||
get_entity_node_save_bulk_query(driver.provider, nodes),
|
||||
nodes=nodes,
|
||||
has_aoss=bool(driver.aoss_client),
|
||||
)
|
||||
await tx.run(
|
||||
get_episodic_edge_save_bulk_query(driver.provider),
|
||||
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
||||
)
|
||||
await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
|
||||
await tx.run(
|
||||
get_entity_edge_save_bulk_query(driver.provider),
|
||||
entity_edges=edges,
|
||||
has_aoss=bool(driver.aoss_client),
|
||||
)
|
||||
|
||||
if driver.aoss_client:
|
||||
driver.save_to_aoss('episodes', episodes)
|
||||
driver.save_to_aoss('entities', nodes)
|
||||
driver.save_to_aoss('entity_edges', edges)
|
||||
|
||||
|
||||
async def extract_nodes_and_edges_bulk(
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
||||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
if driver.aoss_client:
|
||||
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
|
||||
return
|
||||
if delete_existing:
|
||||
|
|
@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
|||
|
||||
range_indices: list[LiteralString] = get_range_indices(driver.provider)
|
||||
|
||||
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
||||
# Don't create fulltext indices if OpenSearch is being used
|
||||
if not driver.aoss_client:
|
||||
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
||||
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
# Skip creating fulltext indices if they already exist. Need to do this manually
|
||||
|
|
@ -149,9 +151,9 @@ async def retrieve_episodes(
|
|||
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.valid_at <= $reference_time
|
||||
"""
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.valid_at <= $reference_time
|
||||
"""
|
||||
+ query_filter
|
||||
+ """
|
||||
RETURN
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.20.4"
|
||||
version = "0.21.0pre1"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
@ -32,6 +32,7 @@ google-genai = ["google-genai>=1.8.0"]
|
|||
kuzu = ["kuzu>=0.11.2"]
|
||||
falkordb = ["falkordb>=1.1.2,<2.0.0"]
|
||||
voyageai = ["voyageai>=0.2.3"]
|
||||
neo4j-opensearch = ["boto3>=1.39.16", "opensearch-py>=3.0.0"]
|
||||
sentence-transformers = ["sentence-transformers>=3.2.1"]
|
||||
neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"]
|
||||
dev = [
|
||||
|
|
|
|||
10
uv.lock
generated
10
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.20.4"
|
||||
version = "0.21.0rc1"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
@ -835,6 +835,10 @@ groq = [
|
|||
kuzu = [
|
||||
{ name = "kuzu" },
|
||||
]
|
||||
neo4j-opensearch = [
|
||||
{ name = "boto3" },
|
||||
{ name = "opensearch-py" },
|
||||
]
|
||||
neptune = [
|
||||
{ name = "boto3" },
|
||||
{ name = "langchain-aws" },
|
||||
|
|
@ -851,6 +855,7 @@ voyageai = [
|
|||
requires-dist = [
|
||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
|
||||
{ name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
|
||||
{ name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
|
||||
{ name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
|
||||
{ name = "diskcache", specifier = ">=5.6.3" },
|
||||
{ name = "diskcache-stubs", marker = "extra == 'dev'", specifier = ">=5.6.3.6.20240818" },
|
||||
|
|
@ -872,6 +877,7 @@ requires-dist = [
|
|||
{ name = "neo4j", specifier = ">=5.26.0" },
|
||||
{ name = "numpy", specifier = ">=1.0.0" },
|
||||
{ name = "openai", specifier = ">=1.91.0" },
|
||||
{ name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
|
||||
{ name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
|
||||
{ name = "posthog", specifier = ">=3.0.0" },
|
||||
{ name = "pydantic", specifier = ">=2.11.5" },
|
||||
|
|
@ -888,7 +894,7 @@ requires-dist = [
|
|||
{ name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" },
|
||||
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" },
|
||||
]
|
||||
provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"]
|
||||
provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "neo4j-opensearch", "sentence-transformers", "neptune", "dev"]
|
||||
|
||||
[[package]]
|
||||
name = "groq"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue