Amazon Neptune Support (#793)

* Rebased Neptune changes based on significant rework done

* Updated the README documentation

* Fixed linting and formatting

* Update README.md

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* Update graphiti_core/driver/neptune_driver.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* Update README.md

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* Addressed feedback from code review

* Updated the README documentation for clarity

* Updated the README and neptune_driver based on PR feedback

* Update node_db_queries.py

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Co-authored-by: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
This commit is contained in:
bechbd 2025-08-20 06:56:03 -08:00 committed by GitHub
parent 9c1e1ad7ef
commit ef56dc779a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 3805 additions and 2460 deletions

View file

@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
Requirements:
- Python 3.10 or higher
- Neo4j 5.26 / FalkorDB 1.1.2 or higher (serves as the embeddings storage backend)
- Neo4j 5.26 / FalkorDB 1.1.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend)
- OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding)
> [!IMPORTANT]
@ -148,6 +148,17 @@ pip install graphiti-core[falkordb]
uv add graphiti-core[falkordb]
```
### Installing with Amazon Neptune Support
If you plan to use Amazon Neptune as your graph database backend, install with the Amazon Neptune extra:
```bash
pip install graphiti-core[neptune]
# or with uv
uv add graphiti-core[neptune]
```
### You can also install optional LLM providers as extras:
```bash
@ -165,6 +176,9 @@ pip install graphiti-core[anthropic,groq,google-genai]
# Install with FalkorDB and LLM providers
pip install graphiti-core[falkordb,anthropic,google-genai]
# Install with Amazon Neptune
pip install graphiti-core[neptune]
```
## Default to Low Concurrency; LLM Provider 429 Rate Limit Errors
@ -184,7 +198,7 @@ If your LLM provider allows higher throughput, you can increase `SEMAPHORE_LIMIT
For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:
1. Connecting to a Neo4j or FalkorDB database
1. Connecting to a Neo4j, Amazon Neptune, or FalkorDB database
2. Initializing Graphiti indices and constraints
3. Adding episodes to the graph (both text and structured JSON)
4. Searching for relationships (edges) using hybrid search
@ -267,6 +281,26 @@ driver = FalkorDriver(
graphiti = Graphiti(graph_driver=driver)
```
#### Amazon Neptune
```python
from graphiti_core import Graphiti
from graphiti_core.driver.neptune_driver import NeptuneDriver
# Create a FalkorDB driver with custom database name
driver = NeptuneDriver(
host=<NEPTUNE ENDPOINT>,
aoss_host=<Amazon OpenSearch Serverless Host>,
port=<PORT> # Optional, defaults to 8182,
aoss_port=<PORT> # Optional, defaults to 443
)
driver = NeptuneDriver(host=neptune_uri, aoss_host=aoss_host, port=neptune_port)
# Pass the driver to Graphiti
graphiti = Graphiti(graph_driver=driver)
```
### Performance Configuration
@ -458,7 +492,7 @@ When you initialize a Graphiti instance, we collect:
- **Graphiti version**: The version you're using
- **Configuration choices**:
- LLM provider type (OpenAI, Azure, Anthropic, etc.)
- Database backend (Neo4j, FalkorDB)
- Database backend (Neo4j, FalkorDB, Amazon Neptune Database or Neptune Analytics)
- Embedder provider (OpenAI, Azure, Voyage, etc.)
### What We Don't Collect

View file

@ -18,6 +18,8 @@ This example demonstrates the basic functionality of Graphiti, including:
- A local DBMS created and started in Neo4j Desktop
- **For FalkorDB**:
- FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup)
- **For Amazon Neptune**:
- Amazon server running (see [Amazon Neptune documentation](https://aws.amazon.com/neptune/developer-resources/) for setup)
## Setup Instructions
@ -42,9 +44,19 @@ export NEO4J_PASSWORD=password
# Optional FalkorDB connection parameters (defaults shown)
export FALKORDB_URI=falkor://localhost:6379
# Optional Amazon Neptune connection parameters
NEPTUNE_HOST=your_neptune_host
NEPTUNE_PORT=your_port_or_8182
AOSS_HOST=your_aoss_host
AOSS_PORT=your_port_or_443
# To use a different database, modify the driver constructor in the script
```
TIP: For Amazon Neptune host string please use the following formats
* For Neptune Database: `neptune-db://<cluster endpoint>`
* For Neptune Analytics: `neptune-graph://<graph identifier>`
3. Run the example:
```bash
@ -52,11 +64,14 @@ python quickstart_neo4j.py
# For FalkorDB
python quickstart_falkordb.py
# For Amazon Neptune
python quickstart_neptune.py
```
## What This Example Demonstrates
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j or FalkorDB
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, Amazon Neptune, or FalkorDB
- **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
- **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
- **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance

View file

@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
class GraphProvider(Enum):
NEO4J = 'neo4j'
FALKORDB = 'falkordb'
NEPTUNE = 'neptune'
class GraphDriverSession(ABC):

View file

@ -0,0 +1,299 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import datetime
import logging
from collections.abc import Coroutine
from typing import Any
import boto3
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10
aoss_indices = [
{
'index_name': 'node_name_and_summary',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'summary': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
'size': DEFAULT_SIZE,
},
},
{
'index_name': 'community_name',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
'size': DEFAULT_SIZE,
},
},
{
'index_name': 'episode_content',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'content': {'type': 'text'},
'source': {'type': 'text'},
'source_description': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
'query': {
'query': {
'multi_match': {
'query': '',
'fields': ['content', 'source', 'source_description', 'group_id'],
}
},
'size': DEFAULT_SIZE,
},
},
{
'index_name': 'edge_name_and_fact',
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'fact': {'type': 'text'},
'group_id': {'type': 'text'},
}
}
},
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
'size': DEFAULT_SIZE,
},
},
]
class NeptuneDriver(GraphDriver):
provider: GraphProvider = GraphProvider.NEPTUNE
def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
"""This initializes a NeptuneDriver for use with Neptune as a backend
Args:
host (str): The Neptune Database or Neptune Analytics host
aoss_host (str): The OpenSearch host value
port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
aoss_port (int, optional): The OpenSearch port. Defaults to 443.
"""
if not host:
raise ValueError('You must provide an endpoint to create a NeptuneDriver')
if host.startswith('neptune-db://'):
# This is a Neptune Database Cluster
endpoint = host.replace('neptune-db://', '')
self.client = NeptuneGraph(endpoint, port)
logger.debug('Creating Neptune Database session for %s', host)
elif host.startswith('neptune-graph://'):
# This is a Neptune Analytics Graph
graphId = host.replace('neptune-graph://', '')
self.client = NeptuneAnalyticsGraph(graphId)
logger.debug('Creating Neptune Graph session for %s', host)
else:
raise ValueError(
'You must provide an endpoint to create a NeptuneDriver as either neptune-db://<endpoint> or neptune-graph://<graphid>'
)
if not aoss_host:
raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')
session = boto3.Session()
self.aoss_client = OpenSearch(
hosts=[{'host': aoss_host, 'port': aoss_port}],
http_auth=Urllib3AWSV4SignerAuth(
session.get_credentials(), session.region_name, 'aoss'
),
use_ssl=True,
verify_certs=True,
connection_class=Urllib3HttpConnection,
pool_maxsize=20,
)
def _sanitize_parameters(self, query, params: dict):
if isinstance(query, list):
queries = []
for q in query:
queries.append(self._sanitize_parameters(q, params))
return queries
else:
for k, v in params.items():
if isinstance(v, datetime.datetime):
params[k] = v.isoformat()
elif isinstance(v, list):
# Handle lists that might contain datetime objects
for i, item in enumerate(v):
if isinstance(item, datetime.datetime):
v[i] = item.isoformat()
query = str(query).replace(f'${k}', f'datetime(${k})')
if isinstance(item, dict):
query = self._sanitize_parameters(query, v[i])
# If the list contains datetime objects, we need to wrap each element with datetime()
if any(isinstance(item, str) and 'T' in item for item in v):
# Create a new list expression with datetime() wrapped around each element
datetime_list = (
'['
+ ', '.join(
f'datetime("{item}")'
if isinstance(item, str) and 'T' in item
else repr(item)
for item in v
)
+ ']'
)
query = str(query).replace(f'${k}', datetime_list)
elif isinstance(v, dict):
query = self._sanitize_parameters(query, v)
return query
async def execute_query(
self, cypher_query_, **kwargs: Any
) -> tuple[dict[str, Any], None, None]:
params = dict(kwargs)
if isinstance(cypher_query_, list):
for q in cypher_query_:
result, _, _ = self._run_query(q[0], q[1])
return result, None, None
else:
return self._run_query(cypher_query_, params)
def _run_query(self, cypher_query_, params):
cypher_query_ = str(self._sanitize_parameters(cypher_query_, params))
try:
result = self.client.query(cypher_query_, params=params)
except Exception as e:
logger.error('Query: %s', cypher_query_)
logger.error('Parameters: %s', params)
logger.error('Error executing query: %s', e)
raise e
return result, None, None
def session(self, database: str | None = None) -> GraphDriverSession:
return NeptuneDriverSession(driver=self)
async def close(self) -> None:
return self.client.client.close()
async def _delete_all_data(self) -> Any:
return await self.execute_query('MATCH (n) DETACH DELETE n')
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
return self.delete_all_indexes_impl()
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
# No matter what happens above, always return True
return self.delete_aoss_indices()
async def create_aoss_indices(self):
for index in aoss_indices:
index_name = index['index_name']
client = self.aoss_client
if not client.indices.exists(index=index_name):
client.indices.create(index=index_name, body=index['body'])
# Sleep for 1 minute to let the index creation complete
await asyncio.sleep(60)
async def delete_aoss_indices(self):
for index in aoss_indices:
index_name = index['index_name']
client = self.aoss_client
if client.indices.exists(index=index_name):
client.indices.delete(index=index_name)
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
for index in aoss_indices:
if name.lower() == index['index_name']:
index['query']['query']['multi_match']['query'] = query_text
query = {'size': limit, 'query': index['query']}
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
return resp
return {}
def save_to_aoss(self, name: str, data: list[dict]) -> int:
for index in aoss_indices:
if name.lower() == index['index_name']:
to_index = []
for d in data:
item = {'_index': name}
for p in index['body']['mappings']['properties']:
item[p] = d[p]
to_index.append(item)
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
if failed > 0:
return success
else:
return 0
return 0
class NeptuneDriverSession(GraphDriverSession):
def __init__(self, driver: NeptuneDriver): # type: ignore[reportUnknownArgumentType]
self.driver = driver
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
# No cleanup needed for Neptune, but method must exist
pass
async def close(self):
# No explicit close needed for Neptune, but method must exist
pass
async def execute_write(self, func, *args, **kwargs):
# Directly await the provided async function with `self` as the transaction/session
return await func(self, *args, **kwargs)
async def run(self, query: str | list, **kwargs: Any) -> Any:
if isinstance(query, list):
res = None
for q in query:
res = await self.driver.execute_query(q, **kwargs)
return res
else:
return await self.driver.execute_query(str(query), **kwargs)

View file

@ -24,13 +24,14 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_RETURN,
ENTITY_EDGE_RETURN,
ENTITY_EDGE_RETURN_NEPTUNE,
EPISODIC_EDGE_RETURN,
EPISODIC_EDGE_SAVE,
get_community_edge_save_query,
@ -214,11 +215,19 @@ class EntityEdge(Edge):
return self.fact_embedding
async def load_fact_embedding(self, driver: GraphDriver):
records, _, _ = await driver.execute_query(
if driver.provider == GraphProvider.NEPTUNE:
query: LiteralString = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
"""
else:
query: LiteralString = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN e.fact_embedding AS fact_embedding
""",
"""
records, _, _ = await driver.execute_query(
query,
uuid=self.uuid,
routing_='r',
)
@ -246,6 +255,9 @@ class EntityEdge(Edge):
edge_data.update(self.attributes or {})
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),
edge_data=edge_data,
@ -262,7 +274,11 @@ class EntityEdge(Edge):
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN
"""
+ ENTITY_EDGE_RETURN,
+ (
ENTITY_EDGE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else ENTITY_EDGE_RETURN
),
uuid=uuid,
routing_='r',
)
@ -284,7 +300,11 @@ class EntityEdge(Edge):
WHERE e.uuid IN $uuids
RETURN
"""
+ ENTITY_EDGE_RETURN,
+ (
ENTITY_EDGE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else ENTITY_EDGE_RETURN
),
uuids=uuids,
routing_='r',
)
@ -321,7 +341,11 @@ class EntityEdge(Edge):
+ """
RETURN
"""
+ ENTITY_EDGE_RETURN
+ (
ENTITY_EDGE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else ENTITY_EDGE_RETURN
)
+ with_embeddings_query
+ """
ORDER BY e.uuid DESC
@ -346,7 +370,11 @@ class EntityEdge(Edge):
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
RETURN
"""
+ ENTITY_EDGE_RETURN,
+ (
ENTITY_EDGE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else ENTITY_EDGE_RETURN
),
node_uuid=node_uuid,
routing_='r',
)

View file

@ -43,47 +43,70 @@ EPISODIC_EDGE_RETURN = """
def get_entity_edge_save_query(provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
RETURN e.uuid AS uuid
"""
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
RETURN e.uuid AS uuid
"""
match provider:
case GraphProvider.FALKORDB:
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
RETURN e.uuid AS uuid
"""
case GraphProvider.NEPTUNE:
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = removeKeyFromMap(removeKeyFromMap($edge_data, "fact_embedding"), "episodes")
SET e.fact_embedding = join([x IN coalesce($edge_data.fact_embedding, []) | toString(x) ], ",")
SET e.episodes = join($edge_data.episodes, ",")
RETURN $edge_data.uuid AS uuid
"""
case _: # Neo4j
return """
MATCH (source:Entity {uuid: $edge_data.source_uuid})
MATCH (target:Entity {uuid: $edge_data.target_uuid})
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
SET e = $edge_data
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
RETURN e.uuid AS uuid
"""
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
WITH r, edge
RETURN edge.uuid AS uuid
"""
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
SET e = edge
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
RETURN edge.uuid AS uuid
"""
match provider:
case GraphProvider.FALKORDB:
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
WITH r, edge
RETURN edge.uuid AS uuid
"""
case GraphProvider.NEPTUNE:
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = removeKeyFromMap(removeKeyFromMap(edge, "fact_embedding"), "episodes")
SET r.fact_embedding = join([x IN coalesce(edge.fact_embedding, []) | toString(x) ], ",")
SET r.episodes = join(edge.episodes, ",")
RETURN edge.uuid AS uuid
"""
case _:
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
SET e = edge
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
RETURN edge.uuid AS uuid
"""
ENTITY_EDGE_RETURN = """
@ -101,24 +124,51 @@ ENTITY_EDGE_RETURN = """
properties(e) AS attributes
"""
ENTITY_EDGE_RETURN_NEPTUNE = """
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.group_id AS group_id,
e.name AS name,
e.fact AS fact,
split(e.episodes, ',') AS episodes,
e.created_at AS created_at,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at,
properties(e) AS attributes
"""
def get_community_edge_save_query(provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
return """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node {uuid: $entity_uuid})
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN e.uuid AS uuid
"""
return """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN e.uuid AS uuid
"""
match provider:
case GraphProvider.FALKORDB:
return """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node {uuid: $entity_uuid})
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN e.uuid AS uuid
"""
case GraphProvider.NEPTUNE:
return """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node {uuid: $entity_uuid})
WHERE node:Entity OR node:Community
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
SET r.uuid= $uuid
SET r.group_id= $group_id
SET r.created_at= $created_at
RETURN r.uuid AS uuid
"""
case _: # Neo4j
return """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN e.uuid AS uuid
"""
COMMUNITY_EDGE_RETURN = """

View file

@ -18,21 +18,45 @@ from typing import Any
from graphiti_core.driver.driver import GraphProvider
EPISODIC_NODE_SAVE = """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
EPISODIC_NODE_SAVE_BULK = """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
def get_episode_node_save_query(provider: GraphProvider) -> str:
match provider:
case GraphProvider.NEPTUNE:
return """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
case _: # Neo4j and FalkorDB
return """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
match provider:
case GraphProvider.NEPTUNE:
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
source: episode.source, content: episode.content,
entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
case _: # Neo4j and FalkorDB
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
EPISODIC_NODE_RETURN = """
e.content AS content,
@ -46,54 +70,96 @@ EPISODIC_NODE_RETURN = """
e.entity_edges AS entity_edges
"""
EPISODIC_NODE_RETURN_NEPTUNE = """
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source,
split(e.entity_edges, ",") AS entity_edges
"""
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
if provider == GraphProvider.FALKORDB:
return f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels}
SET n = $entity_data
RETURN n.uuid AS uuid
"""
return f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels}
SET n = $entity_data
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
RETURN n.uuid AS uuid
"""
match provider:
case GraphProvider.FALKORDB:
return f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels}
SET n = $entity_data
RETURN n.uuid AS uuid
"""
case GraphProvider.NEPTUNE:
label_subquery = ''
for label in labels.split(':'):
label_subquery += f' SET n:{label}\n'
return f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
{label_subquery}
SET n = removeKeyFromMap(removeKeyFromMap($entity_data, "labels"), "name_embedding")
SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",")
RETURN n.uuid AS uuid
"""
case _:
return f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels}
SET n = $entity_data
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
RETURN n.uuid AS uuid
"""
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
if provider == GraphProvider.FALKORDB:
queries = []
for node in nodes:
for label in node['labels']:
match provider:
case GraphProvider.FALKORDB:
queries = []
for node in nodes:
for label in node['labels']:
queries.append(
(
f"""
UNWIND $nodes AS node
MERGE (n:Entity {{uuid: node.uuid}})
SET n:{label}
SET n = node
WITH n, node
SET n.name_embedding = vecf32(node.name_embedding)
RETURN n.uuid AS uuid
""",
{'nodes': [node]},
)
)
return queries
case GraphProvider.NEPTUNE:
queries = []
for node in nodes:
labels = ''
for label in node['labels']:
labels += f' SET n:{label}\n'
queries.append(
(
f"""
f"""
UNWIND $nodes AS node
MERGE (n:Entity {{uuid: node.uuid}})
SET n:{label}
SET n = node
WITH n, node
SET n.name_embedding = vecf32(node.name_embedding)
{labels}
SET n = removeKeyFromMap(removeKeyFromMap(node, "labels"), "name_embedding")
SET n.name_embedding = join([x IN coalesce(node.name_embedding, []) | toString(x) ], ",")
RETURN n.uuid AS uuid
""",
{'nodes': [node]},
)
"""
)
return queries
return """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n:$(node.labels)
SET n = node
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
RETURN n.uuid AS uuid
"""
return queries
case _: # Neo4j
return """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n:$(node.labels)
SET n = node
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
RETURN n.uuid AS uuid
"""
ENTITY_NODE_RETURN = """
@ -108,19 +174,27 @@ ENTITY_NODE_RETURN = """
def get_community_node_save_query(provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: vecf32($name_embedding)}
RETURN n.uuid AS uuid
"""
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid
"""
match provider:
case GraphProvider.FALKORDB:
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: vecf32($name_embedding)}
RETURN n.uuid AS uuid
"""
case GraphProvider.NEPTUNE:
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid
"""
COMMUNITY_NODE_RETURN = """
@ -131,3 +205,12 @@ COMMUNITY_NODE_RETURN = """
n.summary AS summary,
n.created_at AS created_at
"""
COMMUNITY_NODE_RETURN_NEPTUNE = """
n.uuid AS uuid,
n.name AS name,
[x IN split(n.name_embedding, ",") | toFloat(x)] AS name_embedding,
n.group_id AS group_id,
n.summary AS summary,
n.created_at AS created_at
"""

View file

@ -31,11 +31,13 @@ from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_RETURN,
COMMUNITY_NODE_RETURN_NEPTUNE,
ENTITY_NODE_RETURN,
EPISODIC_NODE_RETURN,
EPISODIC_NODE_SAVE,
EPISODIC_NODE_RETURN_NEPTUNE,
get_community_node_save_query,
get_entity_node_save_query,
get_episode_node_save_query,
)
from graphiti_core.utils.datetime_utils import utc_now
@ -89,23 +91,24 @@ class Node(BaseModel, ABC):
async def save(self, driver: GraphDriver): ...
async def delete(self, driver: GraphDriver):
if driver.provider == GraphProvider.FALKORDB:
for label in ['Entity', 'Episodic', 'Community']:
match driver.provider:
case GraphProvider.NEO4J:
await driver.execute_query(
f"""
MATCH (n:{label} {{uuid: $uuid}})
DETACH DELETE n
""",
uuid=self.uuid,
)
else:
await driver.execute_query(
"""
"""
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
DETACH DELETE n
""",
uuid=self.uuid,
)
uuid=self.uuid,
)
case _: # FalkorDB and Neptune
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{uuid: $uuid}})
DETACH DELETE n
""",
uuid=self.uuid,
)
logger.debug(f'Deleted Node: {self.uuid}')
@ -119,28 +122,30 @@ class Node(BaseModel, ABC):
@classmethod
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
if driver.provider == GraphProvider.FALKORDB:
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{group_id: $group_id}})
DETACH DELETE n
""",
group_id=group_id,
)
else:
async with driver.session() as session:
await session.run(
"""
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
CALL {
WITH n
match driver.provider:
case GraphProvider.NEO4J:
async with driver.session() as session:
await session.run(
"""
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
CALL {
WITH n
DETACH DELETE n
} IN TRANSACTIONS OF $batch_size ROWS
""",
group_id=group_id,
batch_size=batch_size,
)
case _: # FalkorDB and Neptune
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{group_id: $group_id}})
DETACH DELETE n
} IN TRANSACTIONS OF $batch_size ROWS
""",
group_id=group_id,
batch_size=batch_size,
)
""",
group_id=group_id,
)
@classmethod
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
@ -189,8 +194,21 @@ class EpisodicNode(Node):
)
async def save(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'episode_content',
[
{
'uuid': self.uuid,
'group_id': self.group_id,
'source': self.source.value,
'content': self.content,
'source_description': self.source_description,
}
],
)
result = await driver.execute_query(
EPISODIC_NODE_SAVE,
get_episode_node_save_query(driver.provider),
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
@ -213,7 +231,11 @@ class EpisodicNode(Node):
MATCH (e:Episodic {uuid: $uuid})
RETURN
"""
+ EPISODIC_NODE_RETURN,
+ (
EPISODIC_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else EPISODIC_NODE_RETURN
),
uuid=uuid,
routing_='r',
)
@ -233,7 +255,11 @@ class EpisodicNode(Node):
WHERE e.uuid IN $uuids
RETURN DISTINCT
"""
+ EPISODIC_NODE_RETURN,
+ (
EPISODIC_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else EPISODIC_NODE_RETURN
),
uuids=uuids,
routing_='r',
)
@ -262,7 +288,11 @@ class EpisodicNode(Node):
+ """
RETURN DISTINCT
"""
+ EPISODIC_NODE_RETURN
+ (
EPISODIC_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else EPISODIC_NODE_RETURN
)
+ """
ORDER BY uuid DESC
"""
@ -284,7 +314,11 @@ class EpisodicNode(Node):
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
RETURN DISTINCT
"""
+ EPISODIC_NODE_RETURN,
+ (
EPISODIC_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else EPISODIC_NODE_RETURN
),
entity_node_uuid=entity_node_uuid,
routing_='r',
)
@ -311,11 +345,18 @@ class EntityNode(Node):
return self.name_embedding
async def load_name_embedding(self, driver: GraphDriver):
records, _, _ = await driver.execute_query(
if driver.provider == GraphProvider.NEPTUNE:
query: LiteralString = """
MATCH (n:Entity {uuid: $uuid})
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
"""
MATCH (n:Entity {uuid: $uuid})
RETURN n.name_embedding AS name_embedding
""",
else:
query: LiteralString = """
MATCH (n:Entity {uuid: $uuid})
RETURN n.name_embedding AS name_embedding
"""
records, _, _ = await driver.execute_query(
query,
uuid=self.uuid,
routing_='r',
)
@ -336,6 +377,9 @@ class EntityNode(Node):
}
entity_data.update(self.attributes or {})
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
labels = ':'.join(self.labels + ['Entity'])
result = await driver.execute_query(
@ -433,8 +477,13 @@ class CommunityNode(Node):
summary: str = Field(description='region summary of member nodes', default_factory=str)
async def save(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'community_name',
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
)
result = await driver.execute_query(
get_community_node_save_query(driver.provider),
get_community_node_save_query(driver.provider), # type: ignore
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
@ -457,11 +506,19 @@ class CommunityNode(Node):
return self.name_embedding
async def load_name_embedding(self, driver: GraphDriver):
records, _, _ = await driver.execute_query(
if driver.provider == GraphProvider.NEPTUNE:
query: LiteralString = """
MATCH (c:Community {uuid: $uuid})
RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
"""
else:
query: LiteralString = """
MATCH (c:Community {uuid: $uuid})
RETURN c.name_embedding AS name_embedding
""",
"""
records, _, _ = await driver.execute_query(
query,
uuid=self.uuid,
routing_='r',
)
@ -478,7 +535,11 @@ class CommunityNode(Node):
MATCH (n:Community {uuid: $uuid})
RETURN
"""
+ COMMUNITY_NODE_RETURN,
+ (
COMMUNITY_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else COMMUNITY_NODE_RETURN
),
uuid=uuid,
routing_='r',
)
@ -498,7 +559,11 @@ class CommunityNode(Node):
WHERE n.uuid IN $uuids
RETURN
"""
+ COMMUNITY_NODE_RETURN,
+ (
COMMUNITY_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else COMMUNITY_NODE_RETURN
),
uuids=uuids,
routing_='r',
)
@ -527,7 +592,11 @@ class CommunityNode(Node):
+ """
RETURN
"""
+ COMMUNITY_NODE_RETURN
+ (
COMMUNITY_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else COMMUNITY_NODE_RETURN
)
+ """
ORDER BY n.uuid DESC
"""

File diff suppressed because it is too large Load diff

View file

@ -32,8 +32,8 @@ from graphiti_core.models.edges.edge_db_queries import (
get_entity_edge_save_bulk_query,
)
from graphiti_core.models.nodes.node_db_queries import (
EPISODIC_NODE_SAVE_BULK,
get_entity_node_save_bulk_query,
get_episode_node_save_bulk_query,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.utils.maintenance.edge_operations import (
@ -155,7 +155,7 @@ async def add_nodes_and_edges_bulk_tx(
edge_data.update(edge.attributes or {})
edges.append(edge_data)
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
await tx.run(entity_node_save_bulk, nodes=nodes)
await tx.run(

View file

@ -21,7 +21,7 @@ from time import time
from pydantic import BaseModel
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import (
CommunityEdge,
EntityEdge,
@ -504,23 +504,46 @@ async def resolve_extracted_edge(
async def filter_existing_duplicate_of_edges(
driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
) -> list[tuple[EntityNode, EntityNode]]:
query: LiteralString = """
UNWIND $duplicate_node_uuids AS duplicate_tuple
MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
RETURN DISTINCT
n.uuid AS source_uuid,
m.uuid AS target_uuid
"""
if not duplicates_node_tuples:
return []
duplicate_nodes_map = {
(source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
}
records, _, _ = await driver.execute_query(
query,
duplicate_node_uuids=list(duplicate_nodes_map.keys()),
routing_='r',
)
if driver.provider == GraphProvider.NEPTUNE:
query: LiteralString = """
UNWIND $duplicate_node_uuids AS duplicate_tuple
MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
RETURN DISTINCT
n.uuid AS source_uuid,
m.uuid AS target_uuid
"""
duplicate_nodes = [
{'source': source.uuid, 'target': target.uuid}
for source, target in duplicates_node_tuples
]
records, _, _ = await driver.execute_query(
query,
duplicate_node_uuids=duplicate_nodes,
routing_='r',
)
else:
query: LiteralString = """
UNWIND $duplicate_node_uuids AS duplicate_tuple
MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
RETURN DISTINCT
n.uuid AS source_uuid,
m.uuid AS target_uuid
"""
records, _, _ = await driver.execute_query(
query,
duplicate_node_uuids=list(duplicate_nodes_map.keys()),
routing_='r',
)
# Remove duplicates that already have the IS_DUPLICATE_OF edge
for record in records:

View file

@ -19,10 +19,13 @@ from datetime import datetime
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN
from graphiti_core.models.nodes.node_db_queries import (
EPISODIC_NODE_RETURN,
EPISODIC_NODE_RETURN_NEPTUNE,
)
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
EPISODE_WINDOW_LEN = 3
@ -31,6 +34,8 @@ logger = logging.getLogger(__name__)
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
if driver.provider == GraphProvider.NEPTUNE:
return # Neptune does not need indexes built
if delete_existing:
records, _, _ = await driver.execute_query(
"""
@ -71,7 +76,7 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
async def delete_group_ids(tx):
await tx.run(
'MATCH (n:Entity|Episodic|Community) WHERE n.group_id IN $group_ids DETACH DELETE n',
'MATCH (n) WHERE (n:Entity OR n:Episodic OR n:Community) AND n.group_id IN $group_ids DETACH DELETE n',
group_ids=group_ids,
)
@ -117,7 +122,11 @@ async def retrieve_episodes(
+ """
RETURN
"""
+ EPISODIC_NODE_RETURN
+ (
EPISODIC_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else EPISODIC_NODE_RETURN
)
+ """
ORDER BY e.valid_at DESC
LIMIT $num_episodes

View file

@ -18,7 +18,7 @@ dependencies = [
"tenacity>=9.0.0",
"numpy>=1.0.0",
"python-dotenv>=1.0.1",
"posthog>=3.0.0",
"posthog>=3.0.0"
]
[project.urls]
@ -32,6 +32,7 @@ google-genai = ["google-genai>=1.8.0"]
falkordb = ["falkordb>=1.1.2,<2.0.0"]
voyageai = ["voyageai>=0.2.3"]
sentence-transformers = ["sentence-transformers>=3.2.1"]
neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"]
dev = [
"pyright>=1.1.380",
"groq>=0.2.0",

View file

@ -20,12 +20,14 @@ import pytest
from dotenv import load_dotenv
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neptune_driver import NeptuneDriver
from graphiti_core.helpers import lucene_sanitize
load_dotenv()
HAS_NEO4J = False
HAS_FALKORDB = False
HAS_NEPTUNE = False
if os.getenv('DISABLE_NEO4J') is None:
try:
from graphiti_core.driver.neo4j_driver import Neo4jDriver
@ -42,6 +44,14 @@ if os.getenv('DISABLE_FALKORDB') is None:
except ImportError:
pass
if os.getenv('DISABLE_NEPTUNE') is None:
try:
from graphiti_core.driver.neptune_driver import NeptuneDriver
HAS_NEPTUNE = True
except ImportError:
pass
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
@ -51,6 +61,10 @@ FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
AOSS_HOST = os.getenv('AOSS_HOST', None)
def get_driver(driver_name: str) -> GraphDriver:
if driver_name == 'neo4j':
@ -66,6 +80,12 @@ def get_driver(driver_name: str) -> GraphDriver:
username=FALKORDB_USER,
password=FALKORDB_PASSWORD,
)
elif driver_name == 'neptune':
return NeptuneDriver(
host=NEPTUNE_HOST,
port=int(NEPTUNE_PORT),
aoss_host=AOSS_HOST,
)
else:
raise ValueError(f'Driver {driver_name} not available')
@ -75,6 +95,8 @@ if HAS_NEO4J:
drivers.append('neo4j')
if HAS_FALKORDB:
drivers.append('falkordb')
if HAS_NEPTUNE:
drivers.append('neptune')
def test_lucene_sanitize():

4130
uv.lock generated

File diff suppressed because it is too large Load diff