OpenSearch Integration for Neo4j (#896)

* move aoss to driver

* add indexes

* don't save vectors to neo4j with aoss

* load embeddings from aoss

* add group_id routing

* add search filters and similarity search

* neptune regression update

* update neptune for regression purposes

* update index creation with aliasing

* regression tested

* update version

* edits

* claude suggestions

* cleanup

* updates

* add embedding dim env var

* use cosine sim

* updates

* updates

* remove unused imports

* update
This commit is contained in:
Preston Rasmussen 2025-09-09 10:51:46 -04:00 committed by GitHub
parent a3479758d5
commit 0884cc00e5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 634 additions and 160 deletions

View file

@ -14,15 +14,30 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import copy import copy
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Coroutine from collections.abc import Coroutine
from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any from typing import Any
from graphiti_core.embedder.client import EMBEDDING_DIM
try:
from opensearchpy import OpenSearch, helpers
_HAS_OPENSEARCH = True
except ImportError:
OpenSearch = None
helpers = None
_HAS_OPENSEARCH = False
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10
class GraphProvider(Enum): class GraphProvider(Enum):
NEO4J = 'neo4j' NEO4J = 'neo4j'
@ -31,6 +46,93 @@ class GraphProvider(Enum):
NEPTUNE = 'neptune' NEPTUNE = 'neptune'
# OpenSearch (AOSS) index definitions shared by drivers that attach an
# aoss_client. Each entry names a logical alias and the mapping body used when
# the physical index is created.
# NOTE(review): `group_id` is mapped as `text`, but the filter builders use
# exact-match `terms` queries against it, which normally requires `keyword` —
# confirm the analyzer behaves as intended.
# NOTE(review): faiss engine support for `cosinesimil` varies by OpenSearch
# version — verify against the target cluster.
aoss_indices = [
    {
        'index_name': 'entities',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'summary': {'type': 'text'},
                    'group_id': {'type': 'text'},
                    'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
                    'name_embedding': {
                        'type': 'knn_vector',
                        # OpenSearch knn_vector takes `dimension`; the
                        # Elasticsearch dense_vector parameters `dims`,
                        # `index` and `similarity` are rejected by the mapper.
                        'dimension': EMBEDDING_DIM,
                        'method': {
                            'engine': 'faiss',
                            'space_type': 'cosinesimil',
                            'name': 'hnsw',
                            'parameters': {'ef_construction': 128, 'm': 16},
                        },
                    },
                }
            }
        },
    },
    {
        'index_name': 'communities',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
    },
    {
        'index_name': 'episodes',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'content': {'type': 'text'},
                    'source': {'type': 'text'},
                    'source_description': {'type': 'text'},
                    'group_id': {'type': 'text'},
                    'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
                    'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
                }
            }
        },
    },
    {
        'index_name': 'entity_edges',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'fact': {'type': 'text'},
                    'group_id': {'type': 'text'},
                    'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
                    'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
                    'expired_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
                    'invalid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
                    'fact_embedding': {
                        'type': 'knn_vector',
                        # See note above: `dimension` is the OpenSearch
                        # parameter name for the vector length.
                        'dimension': EMBEDDING_DIM,
                        'method': {
                            'engine': 'faiss',
                            'space_type': 'cosinesimil',
                            'name': 'hnsw',
                            'parameters': {'ef_construction': 128, 'm': 16},
                        },
                    },
                }
            }
        },
    },
]
class GraphDriverSession(ABC): class GraphDriverSession(ABC):
provider: GraphProvider provider: GraphProvider
@ -61,6 +163,7 @@ class GraphDriver(ABC):
'' # Neo4j (default) syntax does not require a prefix for fulltext queries '' # Neo4j (default) syntax does not require a prefix for fulltext queries
) )
_database: str _database: str
aoss_client: OpenSearch | None # type: ignore
@abstractmethod @abstractmethod
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine: def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@ -87,3 +190,70 @@ class GraphDriver(ABC):
cloned._database = database cloned._database = database
return cloned return cloned
async def delete_all_indexes_impl(self) -> Any:
    """Drop every OpenSearch index managed by this driver.

    Awaits the shared AOSS cleanup helper so the coroutine actually runs.
    (The previous version returned the coroutine object un-awaited, so the
    deletion never executed and a "coroutine was never awaited" warning
    was emitted.)
    """
    return await self.delete_aoss_indices()
async def create_aoss_indices(self):
    """Create one physical OpenSearch index per logical alias, idempotently.

    Aliases that already exist are left untouched. A new physical index gets
    a timestamp suffix and the alias is pointed at it, so a later migration
    can swap the alias atomically. Sleeps to let index creation settle, but
    only when at least one index was actually created (the previous version
    always blocked for 60 seconds, even on a no-op rerun).
    """
    client = self.aoss_client
    if not client:
        logger.warning('No OpenSearch client found')
        return
    created_any = False
    for index in aoss_indices:
        alias_name = index['index_name']
        # If the alias already exists, a previous run created the index: skip.
        if client.indices.exists_alias(name=alias_name):
            continue
        # Physical name carries a timestamp so the alias can be re-pointed.
        ts_suffix = datetime.utcnow().strftime('%Y%m%d%H%M%S')
        physical_index_name = f'{alias_name}_{ts_suffix}'
        client.indices.create(index=physical_index_name, body=index['body'])
        client.indices.put_alias(index=physical_index_name, name=alias_name)
        created_any = True
    # Allow some time for index creation to propagate.
    if created_any:
        await asyncio.sleep(60)
async def delete_aoss_indices(self):
    """Delete every managed OpenSearch index, if a client is configured."""
    client = self.aoss_client
    if not client:
        logger.warning('No OpenSearch client found')
        return
    for index_spec in aoss_indices:
        target = index_spec['index_name']
        if client.indices.exists(index=target):
            client.indices.delete(index=target)
def save_to_aoss(self, name: str, data: list[dict]) -> int:
    """Bulk-index `data` into the OpenSearch index named `name`.

    Only fields declared in the index mapping are sent; `group_id` is used
    as the shard routing key. Returns the number of successfully indexed
    documents, or 0 when no client is configured or `name` is unknown.
    """
    client = self.aoss_client
    if not client or not helpers:
        logger.warning('No OpenSearch client found')
        return 0
    for index in aoss_indices:
        if name.lower() != index['index_name']:
            continue
        to_index = []
        for d in data:
            item = {
                '_index': name,
                '_routing': d.get('group_id'),  # shard routing
            }
            for p in index['body']['mappings']['properties']:
                if p in d:  # protect against missing fields
                    item[p] = d[p]
            to_index.append(item)
        success, failed = helpers.bulk(client, to_index, stats_only=True)
        # The previous `success if failed == 0 else success` returned the
        # same value in both branches and silently dropped failures; keep
        # the return value but surface failures in the log.
        if failed:
            logger.warning('save_to_aoss: %s document(s) failed to index into %s', failed, name)
        return success
    return 0

View file

@ -74,6 +74,7 @@ class FalkorDriverSession(GraphDriverSession):
class FalkorDriver(GraphDriver): class FalkorDriver(GraphDriver):
provider = GraphProvider.FALKORDB provider = GraphProvider.FALKORDB
aoss_client: None = None
def __init__( def __init__(
self, self,

View file

@ -92,6 +92,7 @@ SCHEMA_QUERIES = """
class KuzuDriver(GraphDriver): class KuzuDriver(GraphDriver):
provider: GraphProvider = GraphProvider.KUZU provider: GraphProvider = GraphProvider.KUZU
aoss_client: None = None
def __init__( def __init__(
self, self,

View file

@ -22,14 +22,35 @@ from neo4j import AsyncGraphDatabase, EagerResult
from typing_extensions import LiteralString from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.helpers import semaphore_gather
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
import boto3
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
_HAS_OPENSEARCH = True
except ImportError:
boto3 = None
OpenSearch = None
Urllib3AWSV4SignerAuth = None
Urllib3HttpConnection = None
_HAS_OPENSEARCH = False
class Neo4jDriver(GraphDriver): class Neo4jDriver(GraphDriver):
provider = GraphProvider.NEO4J provider = GraphProvider.NEO4J
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'): def __init__(
self,
uri: str,
user: str | None,
password: str | None,
database: str = 'neo4j',
aoss_host: str | None = None,
aoss_port: int | None = None,
):
super().__init__() super().__init__()
self.client = AsyncGraphDatabase.driver( self.client = AsyncGraphDatabase.driver(
uri=uri, uri=uri,
@ -37,6 +58,24 @@ class Neo4jDriver(GraphDriver):
) )
self._database = database self._database = database
self.aoss_client = None
if aoss_host and aoss_port and boto3 is not None:
try:
session = boto3.Session()
self.aoss_client = OpenSearch( # type: ignore
hosts=[{'host': aoss_host, 'port': aoss_port}],
http_auth=Urllib3AWSV4SignerAuth( # type: ignore
session.get_credentials(), session.region_name, 'aoss'
),
use_ssl=True,
verify_certs=True,
connection_class=Urllib3HttpConnection,
pool_maxsize=20,
) # type: ignore
except Exception as e:
logger.warning(f'Failed to initialize OpenSearch client: {e}')
self.aoss_client = None
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult: async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
# Check if database_ is provided in kwargs. # Check if database_ is provided in kwargs.
# If not populated, set the value to retain backwards compatibility # If not populated, set the value to retain backwards compatibility
@ -60,7 +99,14 @@ class Neo4jDriver(GraphDriver):
async def close(self) -> None: async def close(self) -> None:
return await self.client.close() return await self.client.close()
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]: def delete_all_indexes(self) -> Coroutine:
if self.aoss_client:
return semaphore_gather(
self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
),
self.delete_aoss_indices(),
)
return self.client.execute_query( return self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name', 'CALL db.indexes() YIELD name DROP INDEX name',
) )

View file

@ -22,16 +22,21 @@ from typing import Any
import boto3 import boto3
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider from graphiti_core.driver.driver import (
DEFAULT_SIZE,
GraphDriver,
GraphDriverSession,
GraphProvider,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10
aoss_indices = [ neptune_aoss_indices = [
{ {
'index_name': 'node_name_and_summary', 'index_name': 'node_name_and_summary',
'alias_name': 'entities',
'body': { 'body': {
'mappings': { 'mappings': {
'properties': { 'properties': {
@ -49,6 +54,7 @@ aoss_indices = [
}, },
{ {
'index_name': 'community_name', 'index_name': 'community_name',
'alias_name': 'communities',
'body': { 'body': {
'mappings': { 'mappings': {
'properties': { 'properties': {
@ -65,6 +71,7 @@ aoss_indices = [
}, },
{ {
'index_name': 'episode_content', 'index_name': 'episode_content',
'alias_name': 'episodes',
'body': { 'body': {
'mappings': { 'mappings': {
'properties': { 'properties': {
@ -88,6 +95,7 @@ aoss_indices = [
}, },
{ {
'index_name': 'edge_name_and_fact', 'index_name': 'edge_name_and_fact',
'alias_name': 'facts',
'body': { 'body': {
'mappings': { 'mappings': {
'properties': { 'properties': {
@ -220,54 +228,27 @@ class NeptuneDriver(GraphDriver):
async def _delete_all_data(self) -> Any: async def _delete_all_data(self) -> Any:
return await self.execute_query('MATCH (n) DETACH DELETE n') return await self.execute_query('MATCH (n) DETACH DELETE n')
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
return self.delete_all_indexes_impl()
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
# No matter what happens above, always return True
return self.delete_aoss_indices()
async def create_aoss_indices(self): async def create_aoss_indices(self):
for index in aoss_indices: for index in neptune_aoss_indices:
index_name = index['index_name'] index_name = index['index_name']
client = self.aoss_client client = self.aoss_client
if not client:
raise ValueError(
'You must provide an AOSS endpoint to create an OpenSearch driver.'
)
if not client.indices.exists(index=index_name): if not client.indices.exists(index=index_name):
client.indices.create(index=index_name, body=index['body']) client.indices.create(index=index_name, body=index['body'])
alias_name = index.get('alias_name', index_name)
if not client.indices.exists_alias(name=alias_name, index=index_name):
client.indices.put_alias(index=index_name, name=alias_name)
# Sleep for 1 minute to let the index creation complete # Sleep for 1 minute to let the index creation complete
await asyncio.sleep(60) await asyncio.sleep(60)
async def delete_aoss_indices(self): def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
for index in aoss_indices: return self.delete_all_indexes_impl()
index_name = index['index_name']
client = self.aoss_client
if client.indices.exists(index=index_name):
client.indices.delete(index=index_name)
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
for index in aoss_indices:
if name.lower() == index['index_name']:
index['query']['query']['multi_match']['query'] = query_text
query = {'size': limit, 'query': index['query']}
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
return resp
return {}
def save_to_aoss(self, name: str, data: list[dict]) -> int:
for index in aoss_indices:
if name.lower() == index['index_name']:
to_index = []
for d in data:
item = {'_index': name}
for p in index['body']['mappings']['properties']:
item[p] = d[p]
to_index.append(item)
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
if failed > 0:
return success
else:
return 0
return 0
class NeptuneDriverSession(GraphDriverSession): class NeptuneDriverSession(GraphDriverSession):

View file

@ -255,6 +255,21 @@ class EntityEdge(Edge):
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
""" """
elif driver.aoss_client:
resp = driver.aoss_client.search(
body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1,
},
index='entity_edges',
routing=self.group_id,
)
if resp['hits']['hits']:
self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
return
else:
raise EdgeNotFoundError(self.uuid)
if driver.provider == GraphProvider.KUZU: if driver.provider == GraphProvider.KUZU:
query = """ query = """
@ -292,14 +307,14 @@ class EntityEdge(Edge):
if driver.provider == GraphProvider.KUZU: if driver.provider == GraphProvider.KUZU:
edge_data['attributes'] = json.dumps(self.attributes) edge_data['attributes'] = json.dumps(self.attributes)
result = await driver.execute_query( result = await driver.execute_query(
get_entity_edge_save_query(driver.provider), get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
**edge_data, **edge_data,
) )
else: else:
edge_data.update(self.attributes or {}) edge_data.update(self.attributes or {})
if driver.provider == GraphProvider.NEPTUNE: if driver.aoss_client:
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query( result = await driver.execute_query(
get_entity_edge_save_query(driver.provider), get_entity_edge_save_query(driver.provider),

View file

@ -14,12 +14,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
EMBEDDING_DIM = 1024 EMBEDDING_DIM = int(os.getenv('EMBEDDING_DIM', 1024))
class EmbedderConfig(BaseModel): class EmbedderConfig(BaseModel):

View file

@ -60,7 +60,7 @@ EPISODIC_EDGE_RETURN = """
""" """
def get_entity_edge_save_query(provider: GraphProvider) -> str: def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
match provider: match provider:
case GraphProvider.FALKORDB: case GraphProvider.FALKORDB:
return """ return """
@ -99,17 +99,28 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
RETURN e.uuid AS uuid RETURN e.uuid AS uuid
""" """
case _: # Neo4j case _: # Neo4j
return """ save_embedding_query = (
MATCH (source:Entity {uuid: $edge_data.source_uuid}) """WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
MATCH (target:Entity {uuid: $edge_data.target_uuid}) if not has_aoss
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) else ''
SET e = $edge_data )
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding) return (
(
"""
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
"""
+ save_embedding_query
)
+ """
RETURN e.uuid AS uuid RETURN e.uuid AS uuid
""" """
)
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str: def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
match provider: match provider:
case GraphProvider.FALKORDB: case GraphProvider.FALKORDB:
return """ return """
@ -152,15 +163,24 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
RETURN e.uuid AS uuid RETURN e.uuid AS uuid
""" """
case _: case _:
return """ save_embedding_query = (
UNWIND $entity_edges AS edge 'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
MATCH (source:Entity {uuid: edge.source_node_uuid}) if not has_aoss
MATCH (target:Entity {uuid: edge.target_node_uuid}) else ''
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target) )
SET e = edge return (
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding) """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
SET e = edge
"""
+ save_embedding_query
+ """
RETURN edge.uuid AS uuid RETURN edge.uuid AS uuid
""" """
)
def get_entity_edge_return_query(provider: GraphProvider) -> str: def get_entity_edge_return_query(provider: GraphProvider) -> str:

View file

@ -126,7 +126,7 @@ EPISODIC_NODE_RETURN_NEPTUNE = """
""" """
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str: def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
match provider: match provider:
case GraphProvider.FALKORDB: case GraphProvider.FALKORDB:
return f""" return f"""
@ -161,16 +161,27 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
case _: case _:
return f""" save_embedding_query = (
'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
if not has_aoss
else ''
)
return (
f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}}) MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels} SET n:{labels}
SET n = $entity_data SET n = $entity_data
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding) """
+ save_embedding_query
+ """
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
)
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any: def get_entity_node_save_bulk_query(
provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
) -> str | Any:
match provider: match provider:
case GraphProvider.FALKORDB: case GraphProvider.FALKORDB:
queries = [] queries = []
@ -222,14 +233,23 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
case _: # Neo4j case _: # Neo4j
return """ save_embedding_query = (
UNWIND $nodes AS node 'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
MERGE (n:Entity {uuid: node.uuid}) if not has_aoss
SET n:$(node.labels) else ''
SET n = node )
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding) return (
"""
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n:$(node.labels)
SET n = node
"""
+ save_embedding_query
+ """
RETURN n.uuid AS uuid RETURN n.uuid AS uuid
""" """
)
def get_entity_node_return_query(provider: GraphProvider) -> str: def get_entity_node_return_query(provider: GraphProvider) -> str:

View file

@ -273,20 +273,6 @@ class EpisodicNode(Node):
) )
async def save(self, driver: GraphDriver): async def save(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'episode_content',
[
{
'uuid': self.uuid,
'group_id': self.group_id,
'source': self.source.value,
'content': self.content,
'source_description': self.source_description,
}
],
)
episode_args = { episode_args = {
'uuid': self.uuid, 'uuid': self.uuid,
'name': self.name, 'name': self.name,
@ -299,6 +285,12 @@ class EpisodicNode(Node):
'source': self.source.value, 'source': self.source.value,
} }
if driver.aoss_client:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'episodes',
[episode_args],
)
result = await driver.execute_query( result = await driver.execute_query(
get_episode_node_save_query(driver.provider), **episode_args get_episode_node_save_query(driver.provider), **episode_args
) )
@ -433,6 +425,22 @@ class EntityNode(Node):
MATCH (n:Entity {uuid: $uuid}) MATCH (n:Entity {uuid: $uuid})
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
""" """
elif driver.aoss_client:
resp = driver.aoss_client.search(
body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1,
},
index='entities',
routing=self.group_id,
)
if resp['hits']['hits']:
self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
return
else:
raise NodeNotFoundError(self.uuid)
else: else:
query: LiteralString = """ query: LiteralString = """
MATCH (n:Entity {uuid: $uuid}) MATCH (n:Entity {uuid: $uuid})
@ -470,11 +478,11 @@ class EntityNode(Node):
entity_data.update(self.attributes or {}) entity_data.update(self.attributes or {})
labels = ':'.join(self.labels + ['Entity']) labels = ':'.join(self.labels + ['Entity'])
if driver.provider == GraphProvider.NEPTUNE: if driver.aoss_client:
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query( result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels), get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
entity_data=entity_data, entity_data=entity_data,
) )
@ -570,7 +578,7 @@ class CommunityNode(Node):
async def save(self, driver: GraphDriver): async def save(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'community_name', 'communities',
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}], [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
) )
result = await driver.execute_query( result = await driver.execute_query(

View file

@ -54,6 +54,16 @@ class SearchFilters(BaseModel):
expired_at: list[list[DateFilter]] | None = Field(default=None) expired_at: list[list[DateFilter]] | None = Field(default=None)
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
    """Translate a Cypher comparison operator into its OpenSearch range key."""
    if op is ComparisonOperator.greater_than:
        return 'gt'
    if op is ComparisonOperator.less_than:
        return 'lt'
    if op is ComparisonOperator.greater_than_equal:
        return 'gte'
    if op is ComparisonOperator.less_than_equal:
        return 'lte'
    # Operators without a range equivalent fall back to their literal value.
    return op.value
def node_search_filter_query_constructor( def node_search_filter_query_constructor(
filters: SearchFilters, filters: SearchFilters,
provider: GraphProvider, provider: GraphProvider,
@ -234,3 +244,38 @@ def edge_search_filter_query_constructor(
filter_queries.append(expired_at_filter) filter_queries.append(expired_at_filter)
return filter_queries, filter_params return filter_queries, filter_params
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
    """Build the OpenSearch bool-filter clauses for node searches.

    NOTE(review): an empty `group_ids` list yields a `terms` filter that
    matches no documents — confirm that callers passing `group_ids or []`
    intend "no results" rather than "all groups".
    """
    clauses: list[dict] = [{'terms': {'group_id': group_ids}}]
    labels = search_filters.node_labels
    if labels:
        clauses.append({'terms': {'node_labels': labels}})
    return clauses
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
    """Build the OpenSearch bool-filter clauses for edge searches.

    Date filters are combined as an OR of AND-groups: every condition inside
    a group must hold, and at least one group must match overall.
    """
    clauses: list[dict] = [{'terms': {'group_id': group_ids}}]
    if search_filters.edge_types:
        clauses.append({'terms': {'edge_types': search_filters.edge_types}})
    for field in ('valid_at', 'invalid_at', 'created_at', 'expired_at'):
        ranges = getattr(search_filters, field)
        if not ranges:
            continue
        should_clauses = []
        for and_group in ranges:
            and_filters = [
                {
                    'range': {
                        field: {cypher_to_opensearch_operator(df.comparison_operator): df.date}
                    }
                }
                for df in and_group  # each df is a DateFilter
            ]
            should_clauses.append({'bool': {'filter': and_filters}})
        clauses.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}})
    return clauses

View file

@ -51,6 +51,8 @@ from graphiti_core.nodes import (
) )
from graphiti_core.search.search_filters import ( from graphiti_core.search.search_filters import (
SearchFilters, SearchFilters,
build_aoss_edge_filters,
build_aoss_node_filters,
edge_search_filter_query_constructor, edge_search_filter_query_constructor,
node_search_filter_query_constructor, node_search_filter_query_constructor,
) )
@ -200,7 +202,6 @@ async def edge_fulltext_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_ids = [] input_ids = []
for r in res['hits']['hits']: for r in res['hits']['hits']:
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']}) input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@ -208,11 +209,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values # Match the edge ids and return the values
query = ( query = (
""" """
UNWIND $ids as id UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids WHERE e.group_id IN $group_ids
AND id(e)=id AND id(e)=id
""" """
+ filter_query + filter_query
+ """ + """
AND id(e)=id AND id(e)=id
@ -244,6 +245,31 @@ async def edge_fulltext_search(
) )
else: else:
return [] return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entity_edges',
routing=route,
_source=['uuid'],
query={
'bool': {
'filter': filters,
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
}
},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_edges
else:
return []
else: else:
query = ( query = (
get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider) get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
@ -318,8 +344,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -377,6 +403,32 @@ async def edge_similarity_search(
) )
else: else:
return [] return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entity_edges',
routing=route,
_source=['uuid'],
knn={
'field': 'fact_embedding',
'query_vector': search_vector,
'k': limit,
'num_candidates': 1000,
},
query={'bool': {'filter': filters}},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_edges
else: else:
query = ( query = (
match_query match_query
@ -563,7 +615,6 @@ async def node_fulltext_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_ids = [] input_ids = []
for r in res['hits']['hits']: for r in res['hits']['hits']:
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']}) input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@ -571,11 +622,11 @@ async def node_fulltext_search(
# Match the edge ides and return the values # Match the edge ides and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE n.uuid=i.id WHERE n.uuid=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -592,6 +643,41 @@ async def node_fulltext_search(
) )
else: else:
return [] return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
'entities',
routing=route,
_source=['uuid'],
query={
'bool': {
'filter': filters,
'must': [
{
'multi_match': {
'query': query,
'field': ['name', 'summary'],
'operator': 'or',
}
}
],
}
},
limit=limit,
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get nodes
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entities
else:
return []
else: else:
query = ( query = (
get_nodes_query( get_nodes_query(
@ -648,8 +734,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -678,11 +764,11 @@ async def node_similarity_search(
# Match the edge ides and return the values # Match the edge ides and return the values
query = ( query = (
""" """
UNWIND $ids as i UNWIND $ids as i
MATCH (n:Entity) MATCH (n:Entity)
WHERE id(n)=i.id WHERE id(n)=i.id
RETURN RETURN
""" """
+ get_entity_node_return_query(driver.provider) + get_entity_node_return_query(driver.provider)
+ """ + """
ORDER BY i.score DESC ORDER BY i.score DESC
@ -700,11 +786,36 @@ async def node_similarity_search(
) )
else: else:
return [] return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entities',
routing=route,
_source=['uuid'],
knn={
'field': 'fact_embedding',
'query_vector': search_vector,
'k': limit,
'num_candidates': 1000,
},
query={'bool': {'filter': filters}},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_nodes
else: else:
query = ( query = (
""" """
MATCH (n:Entity) MATCH (n:Entity)
""" """
+ filter_query + filter_query
+ """ + """
WITH n, """ WITH n, """
@ -843,7 +954,6 @@ async def episode_fulltext_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_ids = [] input_ids = []
for r in res['hits']['hits']: for r in res['hits']['hits']:
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']}) input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@ -852,7 +962,7 @@ async def episode_fulltext_search(
query = """ query = """
UNWIND $ids as i UNWIND $ids as i
MATCH (e:Episodic) MATCH (e:Episodic)
WHERE e.uuid=i.id WHERE e.uuid=i.uuid
RETURN RETURN
e.content AS content, e.content AS content,
e.created_at AS created_at, e.created_at AS created_at,
@ -876,6 +986,40 @@ async def episode_fulltext_search(
) )
else: else:
return [] return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
res = driver.aoss_client.search(
'episodes',
routing=route,
_source=['uuid'],
query={
'bool': {
'filter': {'terms': group_ids},
'must': [
{
'multi_match': {
'query': query,
'field': ['name', 'content'],
'operator': 'or',
}
}
],
}
},
limit=limit,
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get nodes
episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys()))
episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return episodes
else:
return []
else: else:
query = ( query = (
get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider) get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
@ -1003,8 +1147,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
MATCH (n:Community) MATCH (n:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1063,8 +1207,8 @@ async def community_similarity_search(
query = ( query = (
""" """
MATCH (c:Community) MATCH (c:Community)
""" """
+ group_filter_query + group_filter_query
+ """ + """
WITH c, WITH c,
@ -1206,9 +1350,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1253,9 +1397,9 @@ async def get_relevant_nodes(
else: else:
query = ( query = (
""" """
UNWIND $nodes AS node UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id}) MATCH (n:Entity {group_id: $group_id})
""" """
+ filter_query + filter_query
+ """ + """
WITH node, n, """ WITH node, n, """
@ -1344,9 +1488,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1416,9 +1560,9 @@ async def get_relevant_edges(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, n, m, """ WITH e, edge, n, m, """
@ -1454,9 +1598,9 @@ async def get_relevant_edges(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge, """ WITH e, edge, """
@ -1529,10 +1673,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH e, edge WITH e, edge
@ -1602,10 +1746,10 @@ async def get_edge_invalidation_candidates(
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, n, m, """ WITH edge, e, n, m, """
@ -1641,10 +1785,10 @@ async def get_edge_invalidation_candidates(
else: else:
query = ( query = (
""" """
UNWIND $edges AS edge UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
""" """
+ filter_query + filter_query
+ """ + """
WITH edge, e, """ WITH edge, e, """

View file

@ -187,12 +187,25 @@ async def add_nodes_and_edges_bulk_tx(
await tx.run(episodic_edge_query, **edge.model_dump()) await tx.run(episodic_edge_query, **edge.model_dump())
else: else:
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes) await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes) await tx.run(
get_entity_node_save_bulk_query(driver.provider, nodes),
nodes=nodes,
has_aoss=bool(driver.aoss_client),
)
await tx.run( await tx.run(
get_episodic_edge_save_bulk_query(driver.provider), get_episodic_edge_save_bulk_query(driver.provider),
episodic_edges=[edge.model_dump() for edge in episodic_edges], episodic_edges=[edge.model_dump() for edge in episodic_edges],
) )
await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges) await tx.run(
get_entity_edge_save_bulk_query(driver.provider),
entity_edges=edges,
has_aoss=bool(driver.aoss_client),
)
if driver.aoss_client:
driver.save_to_aoss('episodes', episodes)
driver.save_to_aoss('entities', nodes)
driver.save_to_aoss('entity_edges', edges)
async def extract_nodes_and_edges_bulk( async def extract_nodes_and_edges_bulk(

View file

@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False): async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
if driver.provider == GraphProvider.NEPTUNE: if driver.aoss_client:
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue] await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
return return
if delete_existing: if delete_existing:
@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
range_indices: list[LiteralString] = get_range_indices(driver.provider) range_indices: list[LiteralString] = get_range_indices(driver.provider)
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider) # Don't create fulltext indices if OpenSearch is being used
if not driver.aoss_client:
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
if driver.provider == GraphProvider.KUZU: if driver.provider == GraphProvider.KUZU:
# Skip creating fulltext indices if they already exist. Need to do this manually # Skip creating fulltext indices if they already exist. Need to do this manually
@ -149,9 +151,9 @@ async def retrieve_episodes(
query: LiteralString = ( query: LiteralString = (
""" """
MATCH (e:Episodic) MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time WHERE e.valid_at <= $reference_time
""" """
+ query_filter + query_filter
+ """ + """
RETURN RETURN

View file

@ -1,7 +1,7 @@
[project] [project]
name = "graphiti-core" name = "graphiti-core"
description = "A temporal graph building library" description = "A temporal graph building library"
version = "0.20.4" version = "0.21.0pre1"
authors = [ authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" },
@ -32,6 +32,7 @@ google-genai = ["google-genai>=1.8.0"]
kuzu = ["kuzu>=0.11.2"] kuzu = ["kuzu>=0.11.2"]
falkordb = ["falkordb>=1.1.2,<2.0.0"] falkordb = ["falkordb>=1.1.2,<2.0.0"]
voyageai = ["voyageai>=0.2.3"] voyageai = ["voyageai>=0.2.3"]
neo4j-opensearch = ["boto3>=1.39.16", "opensearch-py>=3.0.0"]
sentence-transformers = ["sentence-transformers>=3.2.1"] sentence-transformers = ["sentence-transformers>=3.2.1"]
neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"] neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"]
dev = [ dev = [

10
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.20.4" version = "0.21.0rc1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },
@ -835,6 +835,10 @@ groq = [
kuzu = [ kuzu = [
{ name = "kuzu" }, { name = "kuzu" },
] ]
neo4j-opensearch = [
{ name = "boto3" },
{ name = "opensearch-py" },
]
neptune = [ neptune = [
{ name = "boto3" }, { name = "boto3" },
{ name = "langchain-aws" }, { name = "langchain-aws" },
@ -851,6 +855,7 @@ voyageai = [
requires-dist = [ requires-dist = [
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
{ name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" }, { name = "anthropic", marker = "extra == 'dev'", specifier = ">=0.49.0" },
{ name = "boto3", marker = "extra == 'neo4j-opensearch'", specifier = ">=1.39.16" },
{ name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" }, { name = "boto3", marker = "extra == 'neptune'", specifier = ">=1.39.16" },
{ name = "diskcache", specifier = ">=5.6.3" }, { name = "diskcache", specifier = ">=5.6.3" },
{ name = "diskcache-stubs", marker = "extra == 'dev'", specifier = ">=5.6.3.6.20240818" }, { name = "diskcache-stubs", marker = "extra == 'dev'", specifier = ">=5.6.3.6.20240818" },
@ -872,6 +877,7 @@ requires-dist = [
{ name = "neo4j", specifier = ">=5.26.0" }, { name = "neo4j", specifier = ">=5.26.0" },
{ name = "numpy", specifier = ">=1.0.0" }, { name = "numpy", specifier = ">=1.0.0" },
{ name = "openai", specifier = ">=1.91.0" }, { name = "openai", specifier = ">=1.91.0" },
{ name = "opensearch-py", marker = "extra == 'neo4j-opensearch'", specifier = ">=3.0.0" },
{ name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" }, { name = "opensearch-py", marker = "extra == 'neptune'", specifier = ">=3.0.0" },
{ name = "posthog", specifier = ">=3.0.0" }, { name = "posthog", specifier = ">=3.0.0" },
{ name = "pydantic", specifier = ">=2.11.5" }, { name = "pydantic", specifier = ">=2.11.5" },
@ -888,7 +894,7 @@ requires-dist = [
{ name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" }, { name = "voyageai", marker = "extra == 'dev'", specifier = ">=0.2.3" },
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" }, { name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.2.3" },
] ]
provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "sentence-transformers", "neptune", "dev"] provides-extras = ["anthropic", "groq", "google-genai", "kuzu", "falkordb", "voyageai", "neo4j-opensearch", "sentence-transformers", "neptune", "dev"]
[[package]] [[package]]
name = "groq" name = "groq"