Amazon Neptune Support (#793)

* Rebased Neptune changes based on significant rework done

* Updated the README documentation

* Fixed linting and formatting

* Update README.md

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* Update graphiti_core/driver/neptune_driver.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* Update README.md

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* Addressed feedback from code review

* Updated the README documentation for clarity

* Updated the README and neptune_driver based on PR feedback

* Update node_db_queries.py

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Co-authored-by: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
This commit is contained in:
bechbd 2025-08-20 06:56:03 -08:00 committed by GitHub
parent 9c1e1ad7ef
commit ef56dc779a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 3805 additions and 2460 deletions


@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
Requirements:

- Python 3.10 or higher
- Neo4j 5.26 / FalkorDB 1.1.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend)
- OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding)

> [!IMPORTANT]
@ -148,6 +148,17 @@ pip install graphiti-core[falkordb]
uv add graphiti-core[falkordb]
```
### Installing with Amazon Neptune Support
If you plan to use Amazon Neptune as your graph database backend, install with the Amazon Neptune extra:
```bash
pip install graphiti-core[neptune]
# or with uv
uv add graphiti-core[neptune]
```
### You can also install optional LLM providers as extras:

```bash
@ -165,6 +176,9 @@ pip install graphiti-core[anthropic,groq,google-genai]
# Install with FalkorDB and LLM providers
pip install graphiti-core[falkordb,anthropic,google-genai]

# Install with Amazon Neptune
pip install graphiti-core[neptune]
```
## Default to Low Concurrency; LLM Provider 429 Rate Limit Errors
@ -184,7 +198,7 @@ If your LLM provider allows higher throughput, you can increase `SEMAPHORE_LIMIT
For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:

1. Connecting to a Neo4j, Amazon Neptune, or FalkorDB database
2. Initializing Graphiti indices and constraints
3. Adding episodes to the graph (both text and structured JSON)
4. Searching for relationships (edges) using hybrid search
@ -267,6 +281,26 @@ driver = FalkorDriver(
graphiti = Graphiti(graph_driver=driver)
```
#### Amazon Neptune
```python
from graphiti_core import Graphiti
from graphiti_core.driver.neptune_driver import NeptuneDriver

# Create an Amazon Neptune driver
driver = NeptuneDriver(
    host=<NEPTUNE ENDPOINT>,
    aoss_host=<Amazon OpenSearch Serverless Host>,
    port=<PORT>,  # Optional, defaults to 8182
    aoss_port=<PORT>,  # Optional, defaults to 443
)

# Pass the driver to Graphiti
graphiti = Graphiti(graph_driver=driver)
```
### Performance Configuration
@ -458,7 +492,7 @@ When you initialize a Graphiti instance, we collect:
- **Graphiti version**: The version you're using
- **Configuration choices**:
  - LLM provider type (OpenAI, Azure, Anthropic, etc.)
  - Database backend (Neo4j, FalkorDB, Amazon Neptune Database or Neptune Analytics)
  - Embedder provider (OpenAI, Azure, Voyage, etc.)

### What We Don't Collect


@ -18,6 +18,8 @@ This example demonstrates the basic functionality of Graphiti, including:
- A local DBMS created and started in Neo4j Desktop
- **For FalkorDB**:
  - FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup)
- **For Amazon Neptune**:
  - An Amazon Neptune Database cluster or Neptune Analytics graph available (see [Amazon Neptune documentation](https://aws.amazon.com/neptune/developer-resources/) for setup)
## Setup Instructions
@ -42,9 +44,19 @@ export NEO4J_PASSWORD=password
# Optional FalkorDB connection parameters (defaults shown)
export FALKORDB_URI=falkor://localhost:6379

# Optional Amazon Neptune connection parameters
export NEPTUNE_HOST=your_neptune_host
export NEPTUNE_PORT=your_port_or_8182
export AOSS_HOST=your_aoss_host
export AOSS_PORT=your_port_or_443

# To use a different database, modify the driver constructor in the script
```
TIP: For the Amazon Neptune host string, use the following formats:
* For Neptune Database: `neptune-db://<cluster endpoint>`
* For Neptune Analytics: `neptune-graph://<graph identifier>`
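The scheme prefix on the host string determines which Neptune client the driver constructs. A small sketch of that dispatch (`neptune_host_kind` is a hypothetical helper for illustration; the prefix check mirrors the one in `NeptuneDriver.__init__`):

```python
# Hypothetical helper mirroring the endpoint-scheme check in NeptuneDriver.__init__.
def neptune_host_kind(host: str) -> str:
    if host.startswith('neptune-db://'):
        return 'database'  # Neptune Database cluster endpoint
    if host.startswith('neptune-graph://'):
        return 'analytics'  # Neptune Analytics graph identifier
    raise ValueError('Use neptune-db://<cluster endpoint> or neptune-graph://<graph identifier>')

print(neptune_host_kind('neptune-db://my-cluster.cluster-abc.us-east-1.neptune.amazonaws.com'))  # database
print(neptune_host_kind('neptune-graph://g-abc123'))  # analytics
```

Any other scheme (for example `bolt://`) is rejected, which is why the TIP above matters when filling in `NEPTUNE_HOST`.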
3. Run the example:

```bash
@ -52,11 +64,14 @@ python quickstart_neo4j.py
# For FalkorDB
python quickstart_falkordb.py

# For Amazon Neptune
python quickstart_neptune.py
```
## What This Example Demonstrates

- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, Amazon Neptune, or FalkorDB
- **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
- **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
- **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance


@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
class GraphProvider(Enum):
    NEO4J = 'neo4j'
    FALKORDB = 'falkordb'
    NEPTUNE = 'neptune'


class GraphDriverSession(ABC):


@ -0,0 +1,299 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import datetime
import logging
from collections.abc import Coroutine
from typing import Any
import boto3
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10
aoss_indices = [
    {
        'index_name': 'node_name_and_summary',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'summary': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
    {
        'index_name': 'community_name',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
    {
        'index_name': 'episode_content',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'content': {'type': 'text'},
                    'source': {'type': 'text'},
                    'source_description': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {
                'multi_match': {
                    'query': '',
                    'fields': ['content', 'source', 'source_description', 'group_id'],
                }
            },
            'size': DEFAULT_SIZE,
        },
    },
    {
        'index_name': 'edge_name_and_fact',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'fact': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
]
class NeptuneDriver(GraphDriver):
    provider: GraphProvider = GraphProvider.NEPTUNE

    def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
        """This initializes a NeptuneDriver for use with Neptune as a backend

        Args:
            host (str): The Neptune Database or Neptune Analytics host
            aoss_host (str): The OpenSearch host value
            port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
            aoss_port (int, optional): The OpenSearch port. Defaults to 443.
        """
        if not host:
            raise ValueError('You must provide an endpoint to create a NeptuneDriver')

        if host.startswith('neptune-db://'):
            # This is a Neptune Database Cluster
            endpoint = host.replace('neptune-db://', '')
            self.client = NeptuneGraph(endpoint, port)
            logger.debug('Creating Neptune Database session for %s', host)
        elif host.startswith('neptune-graph://'):
            # This is a Neptune Analytics Graph
            graphId = host.replace('neptune-graph://', '')
            self.client = NeptuneAnalyticsGraph(graphId)
            logger.debug('Creating Neptune Graph session for %s', host)
        else:
            raise ValueError(
                'You must provide an endpoint to create a NeptuneDriver as either neptune-db://<endpoint> or neptune-graph://<graphid>'
            )

        if not aoss_host:
            raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')

        session = boto3.Session()
        self.aoss_client = OpenSearch(
            hosts=[{'host': aoss_host, 'port': aoss_port}],
            http_auth=Urllib3AWSV4SignerAuth(
                session.get_credentials(), session.region_name, 'aoss'
            ),
            use_ssl=True,
            verify_certs=True,
            connection_class=Urllib3HttpConnection,
            pool_maxsize=20,
        )
    def _sanitize_parameters(self, query, params: dict):
        if isinstance(query, list):
            queries = []
            for q in query:
                queries.append(self._sanitize_parameters(q, params))
            return queries
        else:
            for k, v in params.items():
                if isinstance(v, datetime.datetime):
                    params[k] = v.isoformat()
                elif isinstance(v, list):
                    # Handle lists that might contain datetime objects
                    for i, item in enumerate(v):
                        if isinstance(item, datetime.datetime):
                            v[i] = item.isoformat()
                            query = str(query).replace(f'${k}', f'datetime(${k})')
                        if isinstance(item, dict):
                            query = self._sanitize_parameters(query, v[i])

                    # If the list contains datetime objects, we need to wrap each element with datetime()
                    if any(isinstance(item, str) and 'T' in item for item in v):
                        # Create a new list expression with datetime() wrapped around each element
                        datetime_list = (
                            '['
                            + ', '.join(
                                f'datetime("{item}")'
                                if isinstance(item, str) and 'T' in item
                                else repr(item)
                                for item in v
                            )
                            + ']'
                        )
                        query = str(query).replace(f'${k}', datetime_list)
                elif isinstance(v, dict):
                    query = self._sanitize_parameters(query, v)
        return query
    async def execute_query(
        self, cypher_query_, **kwargs: Any
    ) -> tuple[dict[str, Any], None, None]:
        params = dict(kwargs)
        if isinstance(cypher_query_, list):
            for q in cypher_query_:
                result, _, _ = self._run_query(q[0], q[1])
            return result, None, None
        else:
            return self._run_query(cypher_query_, params)

    def _run_query(self, cypher_query_, params):
        cypher_query_ = str(self._sanitize_parameters(cypher_query_, params))
        try:
            result = self.client.query(cypher_query_, params=params)
        except Exception as e:
            logger.error('Query: %s', cypher_query_)
            logger.error('Parameters: %s', params)
            logger.error('Error executing query: %s', e)
            raise e

        return result, None, None

    def session(self, database: str | None = None) -> GraphDriverSession:
        return NeptuneDriverSession(driver=self)

    async def close(self) -> None:
        return self.client.client.close()

    async def _delete_all_data(self) -> Any:
        return await self.execute_query('MATCH (n) DETACH DELETE n')

    def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
        return self.delete_all_indexes_impl()

    async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
        # No matter what happens above, always return True
        return self.delete_aoss_indices()

    async def create_aoss_indices(self):
        for index in aoss_indices:
            index_name = index['index_name']
            client = self.aoss_client
            if not client.indices.exists(index=index_name):
                client.indices.create(index=index_name, body=index['body'])
        # Sleep for 1 minute to let the index creation complete
        await asyncio.sleep(60)

    async def delete_aoss_indices(self):
        for index in aoss_indices:
            index_name = index['index_name']
            client = self.aoss_client
            if client.indices.exists(index=index_name):
                client.indices.delete(index=index_name)

    def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
        for index in aoss_indices:
            if name.lower() == index['index_name']:
                index['query']['query']['multi_match']['query'] = query_text
                query = {'size': limit, 'query': index['query']}
                resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
                return resp
        return {}

    def save_to_aoss(self, name: str, data: list[dict]) -> int:
        for index in aoss_indices:
            if name.lower() == index['index_name']:
                to_index = []
                for d in data:
                    item = {'_index': name}
                    for p in index['body']['mappings']['properties']:
                        item[p] = d[p]
                    to_index.append(item)
                success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
                if failed > 0:
                    return success
                else:
                    return 0
        return 0


class NeptuneDriverSession(GraphDriverSession):
    def __init__(self, driver: NeptuneDriver):  # type: ignore[reportUnknownArgumentType]
        self.driver = driver

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # No cleanup needed for Neptune, but method must exist
        pass

    async def close(self):
        # No explicit close needed for Neptune, but method must exist
        pass

    async def execute_write(self, func, *args, **kwargs):
        # Directly await the provided async function with `self` as the transaction/session
        return await func(self, *args, **kwargs)

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        if isinstance(query, list):
            res = None
            for q in query:
                res = await self.driver.execute_query(q, **kwargs)
            return res
        else:
            return await self.driver.execute_query(str(query), **kwargs)

@ -24,13 +24,14 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from typing_extensions import LiteralString

from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
    COMMUNITY_EDGE_RETURN,
    ENTITY_EDGE_RETURN,
    ENTITY_EDGE_RETURN_NEPTUNE,
    EPISODIC_EDGE_RETURN,
    EPISODIC_EDGE_SAVE,
    get_community_edge_save_query,
@ -214,11 +215,19 @@ class EntityEdge(Edge):
        return self.fact_embedding

    async def load_fact_embedding(self, driver: GraphDriver):
        if driver.provider == GraphProvider.NEPTUNE:
            query: LiteralString = """
                MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
                RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
            """
        else:
            query: LiteralString = """
                MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
                RETURN e.fact_embedding AS fact_embedding
            """

        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )
@ -246,6 +255,9 @@ class EntityEdge(Edge):
        edge_data.update(self.attributes or {})

        if driver.provider == GraphProvider.NEPTUNE:
            driver.save_to_aoss('edge_name_and_fact', [edge_data])  # pyright: ignore reportAttributeAccessIssue

        result = await driver.execute_query(
            get_entity_edge_save_query(driver.provider),
            edge_data=edge_data,
@ -262,7 +274,11 @@ class EntityEdge(Edge):
            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
            RETURN
            """
            + (
                ENTITY_EDGE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else ENTITY_EDGE_RETURN
            ),
            uuid=uuid,
            routing_='r',
        )
@ -284,7 +300,11 @@ class EntityEdge(Edge):
            WHERE e.uuid IN $uuids
            RETURN
            """
            + (
                ENTITY_EDGE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else ENTITY_EDGE_RETURN
            ),
            uuids=uuids,
            routing_='r',
        )
@ -321,7 +341,11 @@ class EntityEdge(Edge):
            + """
            RETURN
            """
            + (
                ENTITY_EDGE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else ENTITY_EDGE_RETURN
            )
            + with_embeddings_query
            + """
            ORDER BY e.uuid DESC
@ -346,7 +370,11 @@ class EntityEdge(Edge):
            MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
            RETURN
            """
            + (
                ENTITY_EDGE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else ENTITY_EDGE_RETURN
            ),
            node_uuid=node_uuid,
            routing_='r',
        )


@ -43,47 +43,70 @@ EPISODIC_EDGE_RETURN = """
def get_entity_edge_save_query(provider: GraphProvider) -> str:
    match provider:
        case GraphProvider.FALKORDB:
            return """
                MATCH (source:Entity {uuid: $edge_data.source_uuid})
                MATCH (target:Entity {uuid: $edge_data.target_uuid})
                MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
                SET e = $edge_data
                RETURN e.uuid AS uuid
            """
        case GraphProvider.NEPTUNE:
            return """
                MATCH (source:Entity {uuid: $edge_data.source_uuid})
                MATCH (target:Entity {uuid: $edge_data.target_uuid})
                MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
                SET e = removeKeyFromMap(removeKeyFromMap($edge_data, "fact_embedding"), "episodes")
                SET e.fact_embedding = join([x IN coalesce($edge_data.fact_embedding, []) | toString(x) ], ",")
                SET e.episodes = join($edge_data.episodes, ",")
                RETURN $edge_data.uuid AS uuid
            """
        case _:  # Neo4j
            return """
                MATCH (source:Entity {uuid: $edge_data.source_uuid})
                MATCH (target:Entity {uuid: $edge_data.target_uuid})
                MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
                SET e = $edge_data
                WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
                RETURN e.uuid AS uuid
            """
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
    match provider:
        case GraphProvider.FALKORDB:
            return """
                UNWIND $entity_edges AS edge
                MATCH (source:Entity {uuid: edge.source_node_uuid})
                MATCH (target:Entity {uuid: edge.target_node_uuid})
                MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
                SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
                    created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
                WITH r, edge
                RETURN edge.uuid AS uuid
            """
        case GraphProvider.NEPTUNE:
            return """
                UNWIND $entity_edges AS edge
                MATCH (source:Entity {uuid: edge.source_node_uuid})
                MATCH (target:Entity {uuid: edge.target_node_uuid})
                MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
                SET r = removeKeyFromMap(removeKeyFromMap(edge, "fact_embedding"), "episodes")
                SET r.fact_embedding = join([x IN coalesce(edge.fact_embedding, []) | toString(x) ], ",")
                SET r.episodes = join(edge.episodes, ",")
                RETURN edge.uuid AS uuid
            """
        case _:
            return """
                UNWIND $entity_edges AS edge
                MATCH (source:Entity {uuid: edge.source_node_uuid})
                MATCH (target:Entity {uuid: edge.target_node_uuid})
                MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
                SET e = edge
                WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
                RETURN edge.uuid AS uuid
            """
ENTITY_EDGE_RETURN = """
@ -101,24 +124,51 @@ ENTITY_EDGE_RETURN = """
    properties(e) AS attributes
"""
ENTITY_EDGE_RETURN_NEPTUNE = """
    e.uuid AS uuid,
    n.uuid AS source_node_uuid,
    m.uuid AS target_node_uuid,
    e.group_id AS group_id,
    e.name AS name,
    e.fact AS fact,
    split(e.episodes, ',') AS episodes,
    e.created_at AS created_at,
    e.expired_at AS expired_at,
    e.valid_at AS valid_at,
    e.invalid_at AS invalid_at,
    properties(e) AS attributes
"""
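Note that the Neptune variants above store list-valued properties (embeddings, episode UUIDs) as delimited strings via `join(...)` and recover them with `split(...)`/`toFloat(...)` on read, since Neptune's openCypher surface has no vector property type. A plain-Python sketch of that round trip (illustrative only; the real conversion happens inside the Cypher above):

```python
# Illustrative mirror of the join/split pattern used in the Neptune queries.
def serialize_floats(values: list[float]) -> str:
    # Cypher: join([x IN coalesce(..., []) | toString(x)], ",")
    return ','.join(str(v) for v in values)

def deserialize_floats(text: str) -> list[float]:
    # Cypher: [x IN split(..., ",") | toFloat(x)]
    return [float(x) for x in text.split(',')]

stored = serialize_floats([0.1, -0.25, 3.0])
print(deserialize_floats(stored))  # [0.1, -0.25, 3.0]
```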
def get_community_edge_save_query(provider: GraphProvider) -> str:
    match provider:
        case GraphProvider.FALKORDB:
            return """
                MATCH (community:Community {uuid: $community_uuid})
                MATCH (node {uuid: $entity_uuid})
                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
                SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
                RETURN e.uuid AS uuid
            """
        case GraphProvider.NEPTUNE:
            return """
                MATCH (community:Community {uuid: $community_uuid})
                MATCH (node {uuid: $entity_uuid})
                WHERE node:Entity OR node:Community
                MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
                SET r.uuid = $uuid
                SET r.group_id = $group_id
                SET r.created_at = $created_at
                RETURN r.uuid AS uuid
            """
        case _:  # Neo4j
            return """
                MATCH (community:Community {uuid: $community_uuid})
                MATCH (node:Entity | Community {uuid: $entity_uuid})
                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
                SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
                RETURN e.uuid AS uuid
            """
COMMUNITY_EDGE_RETURN = """


@ -18,21 +18,45 @@ from typing import Any
from graphiti_core.driver.driver import GraphProvider
def get_episode_node_save_query(provider: GraphProvider) -> str:
    match provider:
        case GraphProvider.NEPTUNE:
            return """
                MERGE (n:Episodic {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
                    entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
                RETURN n.uuid AS uuid
            """
        case _:  # Neo4j and FalkorDB
            return """
                MERGE (n:Episodic {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
                    entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
                RETURN n.uuid AS uuid
            """


def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
    match provider:
        case GraphProvider.NEPTUNE:
            return """
                UNWIND $episodes AS episode
                MERGE (n:Episodic {uuid: episode.uuid})
                SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
                    source: episode.source, content: episode.content,
                    entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
                RETURN n.uuid AS uuid
            """
        case _:  # Neo4j and FalkorDB
            return """
                UNWIND $episodes AS episode
                MERGE (n:Episodic {uuid: episode.uuid})
                SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
                    entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
                RETURN n.uuid AS uuid
            """
EPISODIC_NODE_RETURN = """
    e.content AS content,
@ -46,54 +70,96 @@ EPISODIC_NODE_RETURN = """
    e.entity_edges AS entity_edges
"""
EPISODIC_NODE_RETURN_NEPTUNE = """
    e.content AS content,
    e.created_at AS created_at,
    e.valid_at AS valid_at,
    e.uuid AS uuid,
    e.name AS name,
    e.group_id AS group_id,
    e.source_description AS source_description,
    e.source AS source,
    split(e.entity_edges, ",") AS entity_edges
"""
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
    match provider:
        case GraphProvider.FALKORDB:
            return f"""
                MERGE (n:Entity {{uuid: $entity_data.uuid}})
                SET n:{labels}
                SET n = $entity_data
                RETURN n.uuid AS uuid
            """
        case GraphProvider.NEPTUNE:
            label_subquery = ''
            for label in labels.split(':'):
                label_subquery += f' SET n:{label}\n'
            return f"""
                MERGE (n:Entity {{uuid: $entity_data.uuid}})
                {label_subquery}
                SET n = removeKeyFromMap(removeKeyFromMap($entity_data, "labels"), "name_embedding")
                SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",")
                RETURN n.uuid AS uuid
            """
        case _:
            return f"""
                MERGE (n:Entity {{uuid: $entity_data.uuid}})
                SET n:{labels}
                SET n = $entity_data
                WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
                RETURN n.uuid AS uuid
            """
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
    match provider:
        case GraphProvider.FALKORDB:
            queries = []
            for node in nodes:
                for label in node['labels']:
                    queries.append(
                        (
                            f"""
                            UNWIND $nodes AS node
                            MERGE (n:Entity {{uuid: node.uuid}})
                            SET n:{label}
                            SET n = node
                            WITH n, node
                            SET n.name_embedding = vecf32(node.name_embedding)
                            RETURN n.uuid AS uuid
                            """,
                            {'nodes': [node]},
                        )
                    )
            return queries
        case GraphProvider.NEPTUNE:
            queries = []
            for node in nodes:
                labels = ''
                for label in node['labels']:
                    labels += f' SET n:{label}\n'
                queries.append(
                    f"""
                    UNWIND $nodes AS node
                    MERGE (n:Entity {{uuid: node.uuid}})
                    {labels}
                    SET n = removeKeyFromMap(removeKeyFromMap(node, "labels"), "name_embedding")
                    SET n.name_embedding = join([x IN coalesce(node.name_embedding, []) | toString(x) ], ",")
                    RETURN n.uuid AS uuid
                    """
                )
            return queries
        case _:  # Neo4j
            return """
                UNWIND $nodes AS node
                MERGE (n:Entity {uuid: node.uuid})
                SET n:$(node.labels)
                SET n = node
                WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
                RETURN n.uuid AS uuid
            """
ENTITY_NODE_RETURN = """ ENTITY_NODE_RETURN = """
@ -108,19 +174,27 @@ ENTITY_NODE_RETURN = """
def get_community_node_save_query(provider: GraphProvider) -> str:
    match provider:
        case GraphProvider.FALKORDB:
            return """
                MERGE (n:Community {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: vecf32($name_embedding)}
                RETURN n.uuid AS uuid
                """
        case GraphProvider.NEPTUNE:
            return """
                MERGE (n:Community {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
                SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
                RETURN n.uuid AS uuid
                """
        case _:  # Neo4j
            return """
                MERGE (n:Community {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
                WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
                RETURN n.uuid AS uuid
                """
COMMUNITY_NODE_RETURN = """ COMMUNITY_NODE_RETURN = """
@ -131,3 +205,12 @@ COMMUNITY_NODE_RETURN = """
n.summary AS summary, n.summary AS summary,
n.created_at AS created_at n.created_at AS created_at
""" """
COMMUNITY_NODE_RETURN_NEPTUNE = """
n.uuid AS uuid,
n.name AS name,
[x IN split(n.name_embedding, ",") | toFloat(x)] AS name_embedding,
n.group_id AS group_id,
n.summary AS summary,
n.created_at AS created_at
"""


@@ -31,11 +31,13 @@ from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.nodes.node_db_queries import (
    COMMUNITY_NODE_RETURN,
    COMMUNITY_NODE_RETURN_NEPTUNE,
    ENTITY_NODE_RETURN,
    EPISODIC_NODE_RETURN,
    EPISODIC_NODE_RETURN_NEPTUNE,
    get_community_node_save_query,
    get_entity_node_save_query,
    get_episode_node_save_query,
)
from graphiti_core.utils.datetime_utils import utc_now
@@ -89,23 +91,24 @@ class Node(BaseModel, ABC):
    async def save(self, driver: GraphDriver): ...

    async def delete(self, driver: GraphDriver):
        match driver.provider:
            case GraphProvider.NEO4J:
                await driver.execute_query(
                    """
                    MATCH (n:Entity|Episodic|Community {uuid: $uuid})
                    DETACH DELETE n
                    """,
                    uuid=self.uuid,
                )
            case _:  # FalkorDB and Neptune
                for label in ['Entity', 'Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label} {{uuid: $uuid}})
                        DETACH DELETE n
                        """,
                        uuid=self.uuid,
                    )

        logger.debug(f'Deleted Node: {self.uuid}')
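The hunk above keeps a single label-disjunction `MATCH (n:Entity|Episodic|Community ...)` for Neo4j, while the FalkorDB/Neptune fallback issues one delete per label. A small sketch of how that per-label query list could be generated (the helper name is hypothetical):

```python
def per_label_delete_queries(labels: list[str]) -> list[str]:
    # One DETACH DELETE per label, mirroring the FalkorDB/Neptune branch above.
    return [
        f'MATCH (n:{label} {{uuid: $uuid}}) DETACH DELETE n'
        for label in labels
    ]

queries = per_label_delete_queries(['Entity', 'Episodic', 'Community'])
print(queries[0])  # MATCH (n:Entity {uuid: $uuid}) DETACH DELETE n
```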
@@ -119,28 +122,30 @@ class Node(BaseModel, ABC):
    @classmethod
    async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
        match driver.provider:
            case GraphProvider.NEO4J:
                async with driver.session() as session:
                    await session.run(
                        """
                        MATCH (n:Entity|Episodic|Community {group_id: $group_id})
                        CALL {
                            WITH n
                            DETACH DELETE n
                        } IN TRANSACTIONS OF $batch_size ROWS
                        """,
                        group_id=group_id,
                        batch_size=batch_size,
                    )
            case _:  # FalkorDB and Neptune
                for label in ['Entity', 'Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label} {{group_id: $group_id}})
                        DETACH DELETE n
                        """,
                        group_id=group_id,
                    )

    @classmethod
    async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
@@ -189,8 +194,21 @@ class EpisodicNode(Node):
        )

    async def save(self, driver: GraphDriver):
        if driver.provider == GraphProvider.NEPTUNE:
            driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
                'episode_content',
                [
                    {
                        'uuid': self.uuid,
                        'group_id': self.group_id,
                        'source': self.source.value,
                        'content': self.content,
                        'source_description': self.source_description,
                    }
                ],
            )
        result = await driver.execute_query(
            get_episode_node_save_query(driver.provider),
            uuid=self.uuid,
            name=self.name,
            group_id=self.group_id,
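On Neptune, full-text search lives in an OpenSearch Serverless collection, so `save()` first ships a document to the `episode_content` index via `save_to_aoss` before writing the graph node. A sketch of assembling that payload (the field set comes from the hunk above; the builder function itself is illustrative):

```python
def episode_aoss_document(
    uuid: str, group_id: str, source: str, content: str, source_description: str
) -> dict:
    # Field set matches the payload handed to save_to_aoss in the diff.
    return {
        'uuid': uuid,
        'group_id': group_id,
        'source': source,
        'content': content,
        'source_description': source_description,
    }

doc = episode_aoss_document('ep-1', 'group-1', 'message', 'hello world', 'chat transcript')
print(sorted(doc))  # ['content', 'group_id', 'source', 'source_description', 'uuid']
```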
@@ -213,7 +231,11 @@ class EpisodicNode(Node):
            MATCH (e:Episodic {uuid: $uuid})
            RETURN
            """
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            ),
            uuid=uuid,
            routing_='r',
        )
@@ -233,7 +255,11 @@ class EpisodicNode(Node):
            WHERE e.uuid IN $uuids
            RETURN DISTINCT
            """
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            ),
            uuids=uuids,
            routing_='r',
        )
@@ -262,7 +288,11 @@ class EpisodicNode(Node):
            + """
            RETURN DISTINCT
            """
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            )
            + """
            ORDER BY uuid DESC
            """
@@ -284,7 +314,11 @@ class EpisodicNode(Node):
            MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
            RETURN DISTINCT
            """
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            ),
            entity_node_uuid=entity_node_uuid,
            routing_='r',
        )
@@ -311,11 +345,18 @@ class EntityNode(Node):
        return self.name_embedding

    async def load_name_embedding(self, driver: GraphDriver):
        if driver.provider == GraphProvider.NEPTUNE:
            query: LiteralString = """
                MATCH (n:Entity {uuid: $uuid})
                RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
            """
        else:
            query: LiteralString = """
                MATCH (n:Entity {uuid: $uuid})
                RETURN n.name_embedding AS name_embedding
            """
        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )
@@ -336,6 +377,9 @@ class EntityNode(Node):
        }
        entity_data.update(self.attributes or {})

        if driver.provider == GraphProvider.NEPTUNE:
            driver.save_to_aoss('node_name_and_summary', [entity_data])  # pyright: ignore reportAttributeAccessIssue

        labels = ':'.join(self.labels + ['Entity'])

        result = await driver.execute_query(
@@ -433,8 +477,13 @@ class CommunityNode(Node):
    summary: str = Field(description='region summary of member nodes', default_factory=str)

    async def save(self, driver: GraphDriver):
        if driver.provider == GraphProvider.NEPTUNE:
            driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
                'community_name',
                [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
            )
        result = await driver.execute_query(
            get_community_node_save_query(driver.provider),  # type: ignore
            uuid=self.uuid,
            name=self.name,
            group_id=self.group_id,
@@ -457,11 +506,19 @@ class CommunityNode(Node):
        return self.name_embedding

    async def load_name_embedding(self, driver: GraphDriver):
        if driver.provider == GraphProvider.NEPTUNE:
            query: LiteralString = """
                MATCH (c:Community {uuid: $uuid})
                RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
            """
        else:
            query: LiteralString = """
                MATCH (c:Community {uuid: $uuid})
                RETURN c.name_embedding AS name_embedding
            """
        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )
@@ -478,7 +535,11 @@ class CommunityNode(Node):
            MATCH (n:Community {uuid: $uuid})
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            ),
            uuid=uuid,
            routing_='r',
        )
@@ -498,7 +559,11 @@ class CommunityNode(Node):
            WHERE n.uuid IN $uuids
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            ),
            uuids=uuids,
            routing_='r',
        )
@@ -527,7 +592,11 @@ class CommunityNode(Node):
            + """
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            )
            + """
            ORDER BY n.uuid DESC
            """

File diff suppressed because it is too large


@@ -32,8 +32,8 @@ from graphiti_core.models.edges.edge_db_queries import (
    get_entity_edge_save_bulk_query,
)
from graphiti_core.models.nodes.node_db_queries import (
    get_entity_node_save_bulk_query,
    get_episode_node_save_bulk_query,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.utils.maintenance.edge_operations import (

@@ -155,7 +155,7 @@ async def add_nodes_and_edges_bulk_tx(
        edge_data.update(edge.attributes or {})
        edges.append(edge_data)

    await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
    entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
    await tx.run(entity_node_save_bulk, nodes=nodes)
    await tx.run(


@@ -21,7 +21,7 @@ from time import time
from pydantic import BaseModel
from typing_extensions import LiteralString

from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import (
    CommunityEdge,
    EntityEdge,
@@ -504,23 +504,46 @@ async def resolve_extracted_edge(
async def filter_existing_duplicate_of_edges(
    driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
) -> list[tuple[EntityNode, EntityNode]]:
    if not duplicates_node_tuples:
        return []

    duplicate_nodes_map = {
        (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
    }

    if driver.provider == GraphProvider.NEPTUNE:
        query: LiteralString = """
            UNWIND $duplicate_node_uuids AS duplicate_tuple
            MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
            RETURN DISTINCT
                n.uuid AS source_uuid,
                m.uuid AS target_uuid
        """
        duplicate_nodes = [
            {'source': source.uuid, 'target': target.uuid}
            for source, target in duplicates_node_tuples
        ]
        records, _, _ = await driver.execute_query(
            query,
            duplicate_node_uuids=duplicate_nodes,
            routing_='r',
        )
    else:
        query: LiteralString = """
            UNWIND $duplicate_node_uuids AS duplicate_tuple
            MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
            RETURN DISTINCT
                n.uuid AS source_uuid,
                m.uuid AS target_uuid
        """
        records, _, _ = await driver.execute_query(
            query,
            duplicate_node_uuids=list(duplicate_nodes_map.keys()),
            routing_='r',
        )

    # Remove duplicates that already have the IS_DUPLICATE_OF edge
    for record in records:
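The hunk above exists because the two backends index the `UNWIND`-ed parameter differently: Neo4j/FalkorDB receive positional tuples read as `duplicate_tuple[0]` / `duplicate_tuple[1]`, while Neptune receives named maps read as `duplicate_tuple.source` / `duplicate_tuple.target`. A quick sketch of both parameter shapes built from the same pairs:

```python
pairs = [('uuid-a', 'uuid-b'), ('uuid-c', 'uuid-d')]

# Neo4j / FalkorDB: positional tuples, read in Cypher as duplicate_tuple[0] / duplicate_tuple[1]
positional = [list(pair) for pair in pairs]

# Neptune: named maps, read as duplicate_tuple.source / duplicate_tuple.target
named = [{'source': source, 'target': target} for source, target in pairs]

print(named[0])  # {'source': 'uuid-a', 'target': 'uuid-b'}
```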


@@ -19,10 +19,13 @@ from datetime import datetime
from typing_extensions import LiteralString

from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
from graphiti_core.models.nodes.node_db_queries import (
    EPISODIC_NODE_RETURN,
    EPISODIC_NODE_RETURN_NEPTUNE,
)
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record

EPISODE_WINDOW_LEN = 3

@@ -31,6 +34,8 @@ logger = logging.getLogger(__name__)

async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
    if driver.provider == GraphProvider.NEPTUNE:
        return  # Neptune does not need indexes built
    if delete_existing:
        records, _, _ = await driver.execute_query(
            """

@@ -71,7 +76,7 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
    async def delete_group_ids(tx):
        await tx.run(
            'MATCH (n) WHERE (n:Entity OR n:Episodic OR n:Community) AND n.group_id IN $group_ids DETACH DELETE n',
            group_ids=group_ids,
        )

@@ -117,7 +122,11 @@ async def retrieve_episodes(
            + """
            RETURN
            """
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            )
            + """
            ORDER BY e.valid_at DESC
            LIMIT $num_episodes


@@ -18,7 +18,7 @@ dependencies = [
    "tenacity>=9.0.0",
    "numpy>=1.0.0",
    "python-dotenv>=1.0.1",
    "posthog>=3.0.0"
]

[project.urls]

@@ -32,6 +32,7 @@ google-genai = ["google-genai>=1.8.0"]
falkordb = ["falkordb>=1.1.2,<2.0.0"]
voyageai = ["voyageai>=0.2.3"]
sentence-transformers = ["sentence-transformers>=3.2.1"]
neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"]
dev = [
    "pyright>=1.1.380",
    "groq>=0.2.0",


@@ -20,12 +20,14 @@ import pytest
from dotenv import load_dotenv

from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neptune_driver import NeptuneDriver
from graphiti_core.helpers import lucene_sanitize

load_dotenv()

HAS_NEO4J = False
HAS_FALKORDB = False
HAS_NEPTUNE = False

if os.getenv('DISABLE_NEO4J') is None:
    try:
        from graphiti_core.driver.neo4j_driver import Neo4jDriver

@@ -42,6 +44,14 @@ if os.getenv('DISABLE_FALKORDB') is None:
    except ImportError:
        pass

if os.getenv('DISABLE_NEPTUNE') is None:
    try:
        from graphiti_core.driver.neptune_driver import NeptuneDriver

        HAS_NEPTUNE = True
    except ImportError:
        pass

NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')

@@ -51,6 +61,10 @@ FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)

NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
AOSS_HOST = os.getenv('AOSS_HOST', None)

@@ -66,6 +80,12 @@ def get_driver(driver_name: str) -> GraphDriver:
            username=FALKORDB_USER,
            password=FALKORDB_PASSWORD,
        )
    elif driver_name == 'neptune':
        return NeptuneDriver(
            host=NEPTUNE_HOST,
            port=int(NEPTUNE_PORT),
            aoss_host=AOSS_HOST,
        )
    else:
        raise ValueError(f'Driver {driver_name} not available')

@@ -75,6 +95,8 @@ if HAS_NEO4J:
    drivers.append('neo4j')
if HAS_FALKORDB:
    drivers.append('falkordb')
if HAS_NEPTUNE:
    drivers.append('neptune')
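The test harness above enables each backend by combining an env-var kill switch (`DISABLE_NEPTUNE`) with an optional import that flips a `HAS_*` flag. The same pattern, sketched generically (the function name is illustrative):

```python
import importlib

def backend_available(module_name: str, disabled: bool = False) -> bool:
    # Same shape as the HAS_* flags: env-var kill switch first, then optional import.
    if disabled:
        return False
    try:
        importlib.import_module(module_name)
        return True
    except ImportError:
        return False

print(backend_available('json'))                   # True: stdlib module always importable
print(backend_available('no_such_driver_module'))  # False
```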
def test_lucene_sanitize():

uv.lock (generated)

File diff suppressed because it is too large