OpenSearch Integration for Neo4j (#896)

* move aoss to driver

* add indexes

* don't save vectors to neo4j with aoss

* load embeddings from aoss

* add group_id routing

* add search filters and similarity search

* neptune regression update

* update neptune for regression purposes

* update index creation with aliasing

* regression tested

* update version

* edits

* claude suggestions

* cleanup

* updates

* add embedding dim env var

* use cosine sim

* updates

* updates

* remove unused imports

* update
Preston Rasmussen 2025-09-09 10:51:46 -04:00 committed by GitHub
parent a3479758d5
commit 0884cc00e5
16 changed files with 634 additions and 160 deletions
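
For orientation, a minimal usage sketch of the OpenSearch-backed Neo4jDriver introduced by the changes listed above (a sketch only: the module path, endpoint, and credentials below are assumptions, not part of this diff).

# Minimal usage sketch; placeholder endpoint and credentials.
# The new extra added in pyproject.toml can be installed with:
#   pip install "graphiti-core[neo4j-opensearch]"
import os

# EMBEDDING_DIM is read at import time (graphiti_core.embedder.client), so set it
# before importing graphiti modules if a non-default dimension is needed.
os.environ.setdefault('EMBEDDING_DIM', '1024')

from graphiti_core.driver.neo4j_driver import Neo4jDriver  # assumed module path

driver = Neo4jDriver(
    uri='bolt://localhost:7687',
    user='neo4j',
    password='password',
    database='neo4j',
    aoss_host='example-collection.us-east-1.aoss.amazonaws.com',  # placeholder AOSS endpoint
    aoss_port=443,
)

# When aoss_host/aoss_port (and boto3 credentials) resolve, driver.aoss_client is set:
# fulltext and vector search are served from OpenSearch, and embeddings are no longer
# written to Neo4j node/relationship properties.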

View file

@ -14,15 +14,30 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import copy
import logging
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from datetime import datetime
from enum import Enum
from typing import Any
from graphiti_core.embedder.client import EMBEDDING_DIM
try:
from opensearchpy import OpenSearch, helpers
_HAS_OPENSEARCH = True
except ImportError:
OpenSearch = None
helpers = None
_HAS_OPENSEARCH = False
logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10
class GraphProvider(Enum):
NEO4J = 'neo4j'
@ -31,6 +46,93 @@ class GraphProvider(Enum):
NEPTUNE = 'neptune'
aoss_indices = [
{
'index_name': 'entities',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'summary': {'type': 'text'},
'group_id': {'type': 'text'},
'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'name_embedding': {
'type': 'knn_vector',
'dims': EMBEDDING_DIM,
'index': True,
'similarity': 'cosine',
'method': {
'engine': 'faiss',
'space_type': 'cosinesimil',
'name': 'hnsw',
'parameters': {'ef_construction': 128, 'm': 16},
},
},
}
}
},
},
{
'index_name': 'communities',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
},
{
'index_name': 'episodes',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'content': {'type': 'text'},
'source': {'type': 'text'},
'source_description': {'type': 'text'},
'group_id': {'type': 'text'},
'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
}
}
},
},
{
'index_name': 'entity_edges',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'fact': {'type': 'text'},
'group_id': {'type': 'text'},
'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'expired_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'invalid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'fact_embedding': {
'type': 'knn_vector',
'dims': EMBEDDING_DIM,
'index': True,
'similarity': 'cosine',
'method': {
'engine': 'faiss',
'space_type': 'cosinesimil',
'name': 'hnsw',
'parameters': {'ef_construction': 128, 'm': 16},
},
},
}
}
},
},
]
class GraphDriverSession(ABC):
provider: GraphProvider
@ -61,6 +163,7 @@ class GraphDriver(ABC):
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
)
_database: str
aoss_client: OpenSearch | None # type: ignore
@abstractmethod
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@ -87,3 +190,70 @@ class GraphDriver(ABC):
cloned._database = database
return cloned
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
# Delegate index cleanup to the OpenSearch (AOSS) helper and return its coroutine
return self.delete_aoss_indices()
async def create_aoss_indices(self):
client = self.aoss_client
if not client:
logger.warning('No OpenSearch client found')
return
for index in aoss_indices:
alias_name = index['index_name']
# If alias already exists, skip (idempotent behavior)
if client.indices.exists_alias(name=alias_name):
continue
# Build a physical index name with timestamp
ts_suffix = datetime.utcnow().strftime('%Y%m%d%H%M%S')
physical_index_name = f'{alias_name}_{ts_suffix}'
# Create the index
client.indices.create(index=physical_index_name, body=index['body'])
# Point alias to it
client.indices.put_alias(index=physical_index_name, name=alias_name)
# Allow some time for index creation
await asyncio.sleep(60)
async def delete_aoss_indices(self):
for index in aoss_indices:
index_name = index['index_name']
client = self.aoss_client
if not client:
logger.warning('No OpenSearch client found')
return
if client.indices.exists(index=index_name):
client.indices.delete(index=index_name)
def save_to_aoss(self, name: str, data: list[dict]) -> int:
client = self.aoss_client
if not client or not helpers:
logger.warning('No OpenSearch client found')
return 0
for index in aoss_indices:
if name.lower() == index['index_name']:
to_index = []
for d in data:
item = {
'_index': name,
'_routing': d.get('group_id'), # shard routing
}
for p in index['body']['mappings']['properties']:
if p in d: # protect against missing fields
item[p] = d[p]
to_index.append(item)
success, failed = helpers.bulk(client, to_index, stats_only=True)
if failed:
    logger.warning(f'Failed to index {failed} documents into {name}')
return success
return 0
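
The group_id routing pattern above reduces to standard opensearch-py calls; a self-contained sketch follows (local endpoint and documents are placeholders, not part of this change): documents are bulk-indexed with _routing set to the group_id, and reads pass the same routing value so a tenant's queries stay on its shard.

# Standalone sketch of the routing pattern; endpoint and documents are hypothetical.
from opensearchpy import OpenSearch, helpers

client = OpenSearch(hosts=[{'host': 'localhost', 'port': 9200}])  # placeholder endpoint

docs = [
    {
        '_index': 'entities',
        '_routing': 'tenant-1',  # shard routing keyed on group_id
        'uuid': 'n-1',
        'name': 'Alice',
        'summary': 'example node',
        'group_id': 'tenant-1',
    },
]
success, failed = helpers.bulk(client, docs, stats_only=True)

# Reads pass the same routing value so the query only touches that tenant's shard.
resp = client.search(
    index='entities',
    routing='tenant-1',
    body={'query': {'match': {'name': 'Alice'}}, 'size': 10, '_source': ['uuid']},
)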

View file

@ -74,6 +74,7 @@ class FalkorDriverSession(GraphDriverSession):
class FalkorDriver(GraphDriver):
provider = GraphProvider.FALKORDB
aoss_client: None = None
def __init__(
self,

View file

@ -92,6 +92,7 @@ SCHEMA_QUERIES = """
class KuzuDriver(GraphDriver):
provider: GraphProvider = GraphProvider.KUZU
aoss_client: None = None
def __init__(
self,

View file

@ -22,14 +22,35 @@ from neo4j import AsyncGraphDatabase, EagerResult
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.helpers import semaphore_gather
logger = logging.getLogger(__name__)
try:
import boto3
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
_HAS_OPENSEARCH = True
except ImportError:
boto3 = None
OpenSearch = None
Urllib3AWSV4SignerAuth = None
Urllib3HttpConnection = None
_HAS_OPENSEARCH = False
class Neo4jDriver(GraphDriver):
provider = GraphProvider.NEO4J
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
def __init__(
self,
uri: str,
user: str | None,
password: str | None,
database: str = 'neo4j',
aoss_host: str | None = None,
aoss_port: int | None = None,
):
super().__init__()
self.client = AsyncGraphDatabase.driver(
uri=uri,
@ -37,6 +58,24 @@ class Neo4jDriver(GraphDriver):
)
self._database = database
self.aoss_client = None
if aoss_host and aoss_port and boto3 is not None:
try:
session = boto3.Session()
self.aoss_client = OpenSearch( # type: ignore
hosts=[{'host': aoss_host, 'port': aoss_port}],
http_auth=Urllib3AWSV4SignerAuth( # type: ignore
session.get_credentials(), session.region_name, 'aoss'
),
use_ssl=True,
verify_certs=True,
connection_class=Urllib3HttpConnection,
pool_maxsize=20,
) # type: ignore
except Exception as e:
logger.warning(f'Failed to initialize OpenSearch client: {e}')
self.aoss_client = None
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
# Check if database_ is provided in kwargs.
# If not populated, set the value to retain backwards compatibility
@ -60,7 +99,14 @@ class Neo4jDriver(GraphDriver):
async def close(self) -> None:
return await self.client.close()
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
def delete_all_indexes(self) -> Coroutine:
if self.aoss_client:
return semaphore_gather(
self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
),
self.delete_aoss_indices(),
)
return self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
)

View file

@ -22,16 +22,21 @@ from typing import Any
import boto3
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.driver.driver import (
DEFAULT_SIZE,
GraphDriver,
GraphDriverSession,
GraphProvider,
)
logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10
aoss_indices = [
neptune_aoss_indices = [
{
'index_name': 'node_name_and_summary',
'alias_name': 'entities',
'body': {
'mappings': {
'properties': {
@ -49,6 +54,7 @@ aoss_indices = [
},
{
'index_name': 'community_name',
'alias_name': 'communities',
'body': {
'mappings': {
'properties': {
@ -65,6 +71,7 @@ aoss_indices = [
},
{
'index_name': 'episode_content',
'alias_name': 'episodes',
'body': {
'mappings': {
'properties': {
@ -88,6 +95,7 @@ aoss_indices = [
},
{
'index_name': 'edge_name_and_fact',
'alias_name': 'facts',
'body': {
'mappings': {
'properties': {
@ -220,54 +228,27 @@ class NeptuneDriver(GraphDriver):
async def _delete_all_data(self) -> Any:
return await self.execute_query('MATCH (n) DETACH DELETE n')
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
return self.delete_all_indexes_impl()
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
# No matter what happens above, always return True
return self.delete_aoss_indices()
async def create_aoss_indices(self):
for index in aoss_indices:
for index in neptune_aoss_indices:
index_name = index['index_name']
client = self.aoss_client
if not client:
raise ValueError(
'You must provide an AOSS endpoint to create an OpenSearch driver.'
)
if not client.indices.exists(index=index_name):
client.indices.create(index=index_name, body=index['body'])
alias_name = index.get('alias_name', index_name)
if not client.indices.exists_alias(name=alias_name, index=index_name):
client.indices.put_alias(index=index_name, name=alias_name)
# Sleep for 1 minute to let the index creation complete
await asyncio.sleep(60)
async def delete_aoss_indices(self):
for index in aoss_indices:
index_name = index['index_name']
client = self.aoss_client
if client.indices.exists(index=index_name):
client.indices.delete(index=index_name)
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
for index in aoss_indices:
if name.lower() == index['index_name']:
index['query']['query']['multi_match']['query'] = query_text
query = {'size': limit, 'query': index['query']}
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
return resp
return {}
def save_to_aoss(self, name: str, data: list[dict]) -> int:
for index in aoss_indices:
if name.lower() == index['index_name']:
to_index = []
for d in data:
item = {'_index': name}
for p in index['body']['mappings']['properties']:
item[p] = d[p]
to_index.append(item)
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
if failed > 0:
return success
else:
return 0
return 0
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
return self.delete_all_indexes_impl()
class NeptuneDriverSession(GraphDriverSession):

View file

@ -255,6 +255,21 @@ class EntityEdge(Edge):
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
"""
elif driver.aoss_client:
resp = driver.aoss_client.search(
body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1,
},
index='entity_edges',
routing=self.group_id,
)
if resp['hits']['hits']:
self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
return
else:
raise EdgeNotFoundError(self.uuid)
if driver.provider == GraphProvider.KUZU:
query = """
@ -292,14 +307,14 @@ class EntityEdge(Edge):
if driver.provider == GraphProvider.KUZU:
edge_data['attributes'] = json.dumps(self.attributes)
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),
get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
**edge_data,
)
else:
edge_data.update(self.attributes or {})
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
if driver.aoss_client:
driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),

View file

@ -14,12 +14,13 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from abc import ABC, abstractmethod
from collections.abc import Iterable
from pydantic import BaseModel, Field
EMBEDDING_DIM = 1024
EMBEDDING_DIM = int(os.getenv('EMBEDDING_DIM', 1024))
class EmbedderConfig(BaseModel):

View file

@ -60,7 +60,7 @@ EPISODIC_EDGE_RETURN = """
"""
def get_entity_edge_save_query(provider: GraphProvider) -> str:
def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
match provider:
case GraphProvider.FALKORDB:
return """
@ -99,17 +99,28 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
RETURN e.uuid AS uuid
"""
case _: # Neo4j
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
save_embedding_query = (
"""WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
if not has_aoss
else ''
)
return (
(
"""
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
"""
+ save_embedding_query
)
+ """
RETURN e.uuid AS uuid
"""
"""
)
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
match provider:
case GraphProvider.FALKORDB:
return """
@ -152,15 +163,24 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
RETURN e.uuid AS uuid
"""
case _:
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
SET e = edge
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
save_embedding_query = (
'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
if not has_aoss
else ''
)
return (
"""
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
SET e = edge
"""
+ save_embedding_query
+ """
RETURN edge.uuid AS uuid
"""
)
def get_entity_edge_return_query(provider: GraphProvider) -> str:
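
For reference, the two Cypher strings the Neo4j branch above resolves to, reconstructed from the fragments shown for illustration only.

# Illustration (reconstructed, not generated output).
# has_aoss=True: the vector write is skipped because embeddings live in OpenSearch.
QUERY_WITH_AOSS = """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
RETURN e.uuid AS uuid
"""

# has_aoss=False: behavior is unchanged and the vector property is still set in Neo4j.
QUERY_WITHOUT_AOSS = """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
RETURN e.uuid AS uuid
"""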

View file

@ -126,7 +126,7 @@ EPISODIC_NODE_RETURN_NEPTUNE = """
"""
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
match provider:
case GraphProvider.FALKORDB:
return f"""
@ -161,16 +161,27 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
RETURN n.uuid AS uuid
"""
case _:
return f"""
save_embedding_query = (
'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
if not has_aoss
else ''
)
return (
f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels}
SET n = $entity_data
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
"""
+ save_embedding_query
+ """
RETURN n.uuid AS uuid
"""
)
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
def get_entity_node_save_bulk_query(
provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
) -> str | Any:
match provider:
case GraphProvider.FALKORDB:
queries = []
@ -222,14 +233,23 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n:$(node.labels)
SET n = node
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
save_embedding_query = (
'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
if not has_aoss
else ''
)
return (
"""
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n:$(node.labels)
SET n = node
"""
+ save_embedding_query
+ """
RETURN n.uuid AS uuid
"""
)
def get_entity_node_return_query(provider: GraphProvider) -> str:

View file

@ -273,20 +273,6 @@ class EpisodicNode(Node):
)
async def save(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'episode_content',
[
{
'uuid': self.uuid,
'group_id': self.group_id,
'source': self.source.value,
'content': self.content,
'source_description': self.source_description,
}
],
)
episode_args = {
'uuid': self.uuid,
'name': self.name,
@ -299,6 +285,12 @@ class EpisodicNode(Node):
'source': self.source.value,
}
if driver.aoss_client:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'episodes',
[episode_args],
)
result = await driver.execute_query(
get_episode_node_save_query(driver.provider), **episode_args
)
@ -433,6 +425,22 @@ class EntityNode(Node):
MATCH (n:Entity {uuid: $uuid})
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
"""
elif driver.aoss_client:
resp = driver.aoss_client.search(
body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1,
},
index='entities',
routing=self.group_id,
)
if resp['hits']['hits']:
self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
return
else:
raise NodeNotFoundError(self.uuid)
else:
query: LiteralString = """
MATCH (n:Entity {uuid: $uuid})
@ -470,11 +478,11 @@ class EntityNode(Node):
entity_data.update(self.attributes or {})
labels = ':'.join(self.labels + ['Entity'])
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
if driver.aoss_client:
driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels),
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
entity_data=entity_data,
)
@ -570,7 +578,7 @@ class CommunityNode(Node):
async def save(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'community_name',
'communities',
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
)
result = await driver.execute_query(

View file

@ -54,6 +54,16 @@ class SearchFilters(BaseModel):
expired_at: list[list[DateFilter]] | None = Field(default=None)
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
mapping = {
ComparisonOperator.greater_than: 'gt',
ComparisonOperator.less_than: 'lt',
ComparisonOperator.greater_than_equal: 'gte',
ComparisonOperator.less_than_equal: 'lte',
}
return mapping.get(op, op.value)
def node_search_filter_query_constructor(
filters: SearchFilters,
provider: GraphProvider,
@ -234,3 +244,38 @@ def edge_search_filter_query_constructor(
filter_queries.append(expired_at_filter)
return filter_queries, filter_params
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
filters = [{'terms': {'group_id': group_ids}}]
if search_filters.node_labels:
filters.append({'terms': {'node_labels': search_filters.node_labels}})
return filters
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
filters: list[dict] = [{'terms': {'group_id': group_ids}}]
if search_filters.edge_types:
filters.append({'terms': {'edge_types': search_filters.edge_types}})
for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
ranges = getattr(search_filters, field)
if ranges:
# OR of ANDs
should_clauses = []
for and_group in ranges:
and_filters = []
for df in and_group: # df is a DateFilter
range_query = {
'range': {
field: {cypher_to_opensearch_operator(df.comparison_operator): df.date}
}
}
and_filters.append(range_query)
should_clauses.append({'bool': {'filter': and_filters}})
filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}})
return filters
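
A worked example of the clause structure build_aoss_edge_filters produces (inputs are hypothetical; the DateFilter and ComparisonOperator usage mirrors the code above).

# Hypothetical example for one group_id and a single valid_at range;
# the "OR of ANDs" collapses to a single AND group here.
from datetime import datetime, timezone

group_ids = ['tenant-1']
cutoff = datetime(2025, 1, 1, tzinfo=timezone.utc)

# With SearchFilters(valid_at=[[DateFilter(date=cutoff,
#     comparison_operator=ComparisonOperator.greater_than)]]) the helper returns:
expected_filters = [
    {'terms': {'group_id': ['tenant-1']}},
    {
        'bool': {
            'should': [
                {'bool': {'filter': [{'range': {'valid_at': {'gt': cutoff}}}]}}
            ],
            'minimum_should_match': 1,
        }
    },
]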

View file

@ -51,6 +51,8 @@ from graphiti_core.nodes import (
)
from graphiti_core.search.search_filters import (
SearchFilters,
build_aoss_edge_filters,
build_aoss_node_filters,
edge_search_filter_query_constructor,
node_search_filter_query_constructor,
)
@ -200,7 +202,6 @@ async def edge_fulltext_search(
if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_ids = []
for r in res['hits']['hits']:
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@ -208,11 +209,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values
query = (
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
+ filter_query
+ """
AND id(e)=id
@ -244,6 +245,31 @@ async def edge_fulltext_search(
)
else:
return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entity_edges',
routing=route,
_source=['uuid'],
query={
'bool': {
'filter': filters,
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
}
},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_edges
else:
return []
else:
query = (
get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
@ -318,8 +344,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -377,6 +403,32 @@ async def edge_similarity_search(
)
else:
return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entity_edges',
routing=route,
_source=['uuid'],
knn={
'field': 'fact_embedding',
'query_vector': search_vector,
'k': limit,
'num_candidates': 1000,
},
query={'bool': {'filter': filters}},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_edges
else:
query = (
match_query
@ -563,7 +615,6 @@ async def node_fulltext_search(
if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_ids = []
for r in res['hits']['hits']:
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@ -571,11 +622,11 @@ async def node_fulltext_search(
# Match the node ids and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -592,6 +643,41 @@ async def node_fulltext_search(
)
else:
return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entities',
routing=route,
_source=['uuid'],
query={
'bool': {
'filter': filters,
'must': [
{
'multi_match': {
'query': query,
'fields': ['name', 'summary'],
'operator': 'or',
}
}
],
}
},
size=limit,
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get nodes
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entities
else:
return []
else:
query = (
get_nodes_query(
@ -648,8 +734,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -678,11 +764,11 @@ async def node_similarity_search(
# Match the node ids and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -700,11 +786,36 @@ async def node_similarity_search(
)
else:
return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entities',
routing=route,
_source=['uuid'],
knn={
'field': 'name_embedding',
'query_vector': search_vector,
'k': limit,
'num_candidates': 1000,
},
query={'bool': {'filter': filters}},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get nodes
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_nodes
else:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
WITH n, """
@ -843,7 +954,6 @@ async def episode_fulltext_search(
if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_ids = []
for r in res['hits']['hits']:
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@ -852,7 +962,7 @@ async def episode_fulltext_search(
query = """
UNWIND $ids as i
MATCH (e:Episodic)
WHERE e.uuid=i.id
WHERE e.uuid=i.uuid
RETURN
e.content AS content,
e.created_at AS created_at,
@ -876,6 +986,40 @@ async def episode_fulltext_search(
)
else:
return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
res = driver.aoss_client.search(
index='episodes',
routing=route,
_source=['uuid'],
query={
'bool': {
'filter': [{'terms': {'group_id': group_ids or []}}],
'must': [
{
'multi_match': {
'query': query,
'fields': ['name', 'content'],
'operator': 'or',
}
}
],
}
},
size=limit,
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get episodes
episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys()))
episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return episodes
else:
return []
else:
query = (
get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
@ -1003,8 +1147,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Community)
"""
MATCH (n:Community)
"""
+ group_filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1063,8 +1207,8 @@ async def community_similarity_search(
query = (
"""
MATCH (c:Community)
"""
MATCH (c:Community)
"""
+ group_filter_query
+ """
WITH c,
@ -1206,9 +1350,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1253,9 +1397,9 @@ async def get_relevant_nodes(
else:
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1344,9 +1488,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge
@ -1416,9 +1560,9 @@ async def get_relevant_edges(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, n, m, """
@ -1454,9 +1598,9 @@ async def get_relevant_edges(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, """
@ -1529,10 +1673,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH e, edge
@ -1602,10 +1746,10 @@ async def get_edge_invalidation_candidates(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
+ filter_query
+ """
WITH edge, e, n, m, """
@ -1641,10 +1785,10 @@ async def get_edge_invalidation_candidates(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH edge, e, """
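
The routed k-NN searches in this file can also be expressed with an explicit request body; a sketch using the OpenSearch k-NN query clause follows (the index and field names come from the entity_edges mapping earlier in this diff; the driver, vector, and group_id are placeholders).

# Body-style sketch of a routed vector search; assumes a driver constructed as in
# the earlier sketch, and a search_vector matching EMBEDDING_DIM.
search_vector = [0.0] * 1024  # placeholder vector
body = {
    'size': 10,
    '_source': ['uuid'],
    'query': {
        'bool': {
            'filter': [{'terms': {'group_id': ['tenant-1']}}],
            'must': [{'knn': {'fact_embedding': {'vector': search_vector, 'k': 10}}}],
        }
    },
}
resp = driver.aoss_client.search(index='entity_edges', body=body, routing='tenant-1')
uuid_to_score = {hit['_source']['uuid']: hit['_score'] for hit in resp['hits']['hits']}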

View file

@ -187,12 +187,25 @@ async def add_nodes_and_edges_bulk_tx(
await tx.run(episodic_edge_query, **edge.model_dump())
else:
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
await tx.run(
    get_entity_node_save_bulk_query(driver.provider, nodes, has_aoss=bool(driver.aoss_client)),
    nodes=nodes,
)
await tx.run(
get_episodic_edge_save_bulk_query(driver.provider),
episodic_edges=[edge.model_dump() for edge in episodic_edges],
)
await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
await tx.run(
    get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)),
    entity_edges=edges,
)
if driver.aoss_client:
driver.save_to_aoss('episodes', episodes)
driver.save_to_aoss('entities', nodes)
driver.save_to_aoss('entity_edges', edges)
async def extract_nodes_and_edges_bulk(

View file

@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
if driver.provider == GraphProvider.NEPTUNE:
if driver.aoss_client:
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
return
if delete_existing:
@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
range_indices: list[LiteralString] = get_range_indices(driver.provider)
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
# Don't create fulltext indices if OpenSearch is being used
if not driver.aoss_client:
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
if driver.provider == GraphProvider.KUZU:
# Skip creating fulltext indices if they already exist. Need to do this manually
@ -149,9 +151,9 @@ async def retrieve_episodes(
query: LiteralString = (
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
+ query_filter
+ """
RETURN

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.20.4"
version = "0.21.0pre1"
authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
@ -32,6 +32,7 @@ google-genai = ["google-genai>=1.8.0"]
kuzu = ["kuzu>=0.11.2"]
falkordb = ["falkordb>=1.1.2,<2.0.0"]
voyageai = ["voyageai>=0.2.3"]
neo4j-opensearch = ["boto3>=1.39.16", "opensearch-py>=3.0.0"]
sentence-transformers = ["sentence-transformers>=3.2.1"]
neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"]
dev = [

uv.lock (generated)
View file

@ -783,7 +783,7 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.20.4"
version = "0.21.0rc1"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },
@ -835,6 +835,10 @@ groq = [
kuzu = [
{ name = "kuzu" },
]
neo4j-opensearch = [
{ name = "boto3" },
{ name = "opensearch-py" },
]
neptune = [
{ name = "boto3" },
{ name = "langchain-aws" },
@ -851,6 +855,7 @@ voyageai = [
requires-dist = [
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
{ name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
{ name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
{ name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
{ name = "diskcache", specifier = ">=5.6.3" },
{ name = "diskcache-stubs", marker = "extra == 'dev'", specifier = ">=5.6.3.6.20240818" },
@ -872,6 +877,7 @@ requires-dist = [
{ name = "neo4j", specifier = ">=5.26.0" },
{ name = "numpy", specifier = ">=1.0.0" },
{ name = "openai", specifier = ">=1.91.0" },
{ name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
{ name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
{ name = "posthog", specifier = ">=3.0.0" },
{ name = "pydantic", specifier = ">=2.11.5" },
@ -888,7 +894,7 @@ requires-dist = [
{ name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" },
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" },
]
provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"]
provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "neo4j-opensearch", "sentence-transformers", "neptune", "dev"]
[[package]]
name = "groq"