Amazon Neptune Support (#793)

* Rebased Neptune changes based on significant rework done
* Updated the README documentation
* Fixed linting and formatting
* Addressed feedback from code review
* Updated the README documentation for clarity
* Updated the README and neptune_driver based on PR feedback
* Update node_db_queries.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Co-authored-by: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>

This commit is contained in:
parent 9c1e1ad7ef
commit ef56dc779a

15 changed files with 3805 additions and 2460 deletions
README.md

@@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
 Requirements:
 
 - Python 3.10 or higher
-- Neo4j 5.26 / FalkorDB 1.1.2 or higher (serves as the embeddings storage backend)
+- Neo4j 5.26 / FalkorDB 1.1.2 / Amazon Neptune Database Cluster or Neptune Analytics Graph + Amazon OpenSearch Serverless collection (serves as the full text search backend)
 - OpenAI API key (Graphiti defaults to OpenAI for LLM inference and embedding)
 
 > [!IMPORTANT]

@@ -148,6 +148,17 @@ pip install graphiti-core[falkordb]
 uv add graphiti-core[falkordb]
 ```
 
+### Installing with Amazon Neptune Support
+
+If you plan to use Amazon Neptune as your graph database backend, install with the Amazon Neptune extra:
+
+```bash
+pip install graphiti-core[neptune]
+
+# or with uv
+uv add graphiti-core[neptune]
+```
+
 ### You can also install optional LLM providers as extras:
 
 ```bash

@@ -165,6 +176,9 @@ pip install graphiti-core[anthropic,groq,google-genai]
 
 # Install with FalkorDB and LLM providers
 pip install graphiti-core[falkordb,anthropic,google-genai]
+
+# Install with Amazon Neptune
+pip install graphiti-core[neptune]
 ```
 
 ## Default to Low Concurrency; LLM Provider 429 Rate Limit Errors

@@ -184,7 +198,7 @@ If your LLM provider allows higher throughput, you can increase `SEMAPHORE_LIMIT
 
 For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:
 
-1. Connecting to a Neo4j or FalkorDB database
+1. Connecting to a Neo4j, Amazon Neptune, or FalkorDB database
 2. Initializing Graphiti indices and constraints
 3. Adding episodes to the graph (both text and structured JSON)
 4. Searching for relationships (edges) using hybrid search

@@ -267,6 +281,26 @@ driver = FalkorDriver(
 graphiti = Graphiti(graph_driver=driver)
 ```
 
+#### Amazon Neptune
+
+```python
+from graphiti_core import Graphiti
+from graphiti_core.driver.neptune_driver import NeptuneDriver
+
+# Create an Amazon Neptune driver
+driver = NeptuneDriver(
+    host=<NEPTUNE ENDPOINT>,
+    aoss_host=<Amazon OpenSearch Serverless Host>,
+    port=<PORT>,  # Optional, defaults to 8182
+    aoss_port=<PORT>,  # Optional, defaults to 443
+)
+
+driver = NeptuneDriver(host=neptune_uri, aoss_host=aoss_host, port=neptune_port)
+
+# Pass the driver to Graphiti
+graphiti = Graphiti(graph_driver=driver)
+```
+
 ### Performance Configuration

@@ -458,7 +492,7 @@ When you initialize a Graphiti instance, we collect:
 - **Graphiti version**: The version you're using
 - **Configuration choices**:
   - LLM provider type (OpenAI, Azure, Anthropic, etc.)
-  - Database backend (Neo4j, FalkorDB)
+  - Database backend (Neo4j, FalkorDB, Amazon Neptune Database or Neptune Analytics)
   - Embedder provider (OpenAI, Azure, Voyage, etc.)
 
 ### What We Don't Collect
examples/quickstart/README.md

@@ -18,6 +18,8 @@ This example demonstrates the basic functionality of Graphiti, including:
   - A local DBMS created and started in Neo4j Desktop
 - **For FalkorDB**:
   - FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup)
+- **For Amazon Neptune**:
+  - An Amazon Neptune cluster running (see [Amazon Neptune documentation](https://aws.amazon.com/neptune/developer-resources/) for setup)
 
 ## Setup Instructions

@@ -42,9 +44,19 @@ export NEO4J_PASSWORD=password
 # Optional FalkorDB connection parameters (defaults shown)
 export FALKORDB_URI=falkor://localhost:6379
 
+# Optional Amazon Neptune connection parameters
+NEPTUNE_HOST=your_neptune_host
+NEPTUNE_PORT=your_port_or_8182
+AOSS_HOST=your_aoss_host
+AOSS_PORT=your_port_or_443
+
 # To use a different database, modify the driver constructor in the script
 ```
 
+TIP: For the Amazon Neptune host string, use the following formats:
+
+* For Neptune Database: `neptune-db://<cluster endpoint>`
+* For Neptune Analytics: `neptune-graph://<graph identifier>`
+
 3. Run the example:
 
 ```bash

@@ -52,11 +64,14 @@ python quickstart_neo4j.py
 
 # For FalkorDB
 python quickstart_falkordb.py
+
+# For Amazon Neptune
+python quickstart_neptune.py
 ```
 
 ## What This Example Demonstrates
 
-- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j or FalkorDB
+- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j, Amazon Neptune, or FalkorDB
 - **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
 - **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
 - **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance
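The TIP in the quickstart README above distinguishes the two host-string schemes. A small sketch of that scheme check (`neptune_mode` is a hypothetical helper name; the real dispatch lives in `NeptuneDriver.__init__`):

```python
def neptune_mode(host: str) -> str:
    # Mirrors the scheme dispatch performed by NeptuneDriver.__init__
    if host.startswith('neptune-db://'):
        return 'database'   # Neptune Database cluster endpoint
    if host.startswith('neptune-graph://'):
        return 'analytics'  # Neptune Analytics graph identifier
    raise ValueError(
        'host must be neptune-db://<cluster endpoint> or neptune-graph://<graph identifier>'
    )

mode = neptune_mode('neptune-graph://g-abc123')
```

The prefix is stripped before use, so the remainder must be a bare cluster endpoint or graph identifier, not a URL with a port.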
graphiti_core/driver/driver.py

@@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
 class GraphProvider(Enum):
     NEO4J = 'neo4j'
     FALKORDB = 'falkordb'
+    NEPTUNE = 'neptune'
 
 
 class GraphDriverSession(ABC):
graphiti_core/driver/neptune_driver.py (new file, 299 lines)

@@ -0,0 +1,299 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import datetime
import logging
from collections.abc import Coroutine
from typing import Any

import boto3
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider

logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10

aoss_indices = [
    {
        'index_name': 'node_name_and_summary',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'summary': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
    {
        'index_name': 'community_name',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
    {
        'index_name': 'episode_content',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'content': {'type': 'text'},
                    'source': {'type': 'text'},
                    'source_description': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {
                'multi_match': {
                    'query': '',
                    'fields': ['content', 'source', 'source_description', 'group_id'],
                }
            },
            'size': DEFAULT_SIZE,
        },
    },
    {
        'index_name': 'edge_name_and_fact',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'fact': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
]


class NeptuneDriver(GraphDriver):
    provider: GraphProvider = GraphProvider.NEPTUNE

    def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
        """This initializes a NeptuneDriver for use with Neptune as a backend

        Args:
            host (str): The Neptune Database or Neptune Analytics host
            aoss_host (str): The OpenSearch host value
            port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
            aoss_port (int, optional): The OpenSearch port. Defaults to 443.
        """
        if not host:
            raise ValueError('You must provide an endpoint to create a NeptuneDriver')

        if host.startswith('neptune-db://'):
            # This is a Neptune Database Cluster
            endpoint = host.replace('neptune-db://', '')
            self.client = NeptuneGraph(endpoint, port)
            logger.debug('Creating Neptune Database session for %s', host)
        elif host.startswith('neptune-graph://'):
            # This is a Neptune Analytics Graph
            graphId = host.replace('neptune-graph://', '')
            self.client = NeptuneAnalyticsGraph(graphId)
            logger.debug('Creating Neptune Graph session for %s', host)
        else:
            raise ValueError(
                'You must provide an endpoint to create a NeptuneDriver as either neptune-db://<endpoint> or neptune-graph://<graphid>'
            )

        if not aoss_host:
            raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')

        session = boto3.Session()
        self.aoss_client = OpenSearch(
            hosts=[{'host': aoss_host, 'port': aoss_port}],
            http_auth=Urllib3AWSV4SignerAuth(
                session.get_credentials(), session.region_name, 'aoss'
            ),
            use_ssl=True,
            verify_certs=True,
            connection_class=Urllib3HttpConnection,
            pool_maxsize=20,
        )

    def _sanitize_parameters(self, query, params: dict):
        if isinstance(query, list):
            queries = []
            for q in query:
                queries.append(self._sanitize_parameters(q, params))
            return queries
        else:
            for k, v in params.items():
                if isinstance(v, datetime.datetime):
                    params[k] = v.isoformat()
                elif isinstance(v, list):
                    # Handle lists that might contain datetime objects
                    for i, item in enumerate(v):
                        if isinstance(item, datetime.datetime):
                            v[i] = item.isoformat()
                            query = str(query).replace(f'${k}', f'datetime(${k})')
                        if isinstance(item, dict):
                            query = self._sanitize_parameters(query, v[i])

                    # If the list contains datetime objects, we need to wrap each element with datetime()
                    if any(isinstance(item, str) and 'T' in item for item in v):
                        # Create a new list expression with datetime() wrapped around each element
                        datetime_list = (
                            '['
                            + ', '.join(
                                f'datetime("{item}")'
                                if isinstance(item, str) and 'T' in item
                                else repr(item)
                                for item in v
                            )
                            + ']'
                        )
                        query = str(query).replace(f'${k}', datetime_list)
                elif isinstance(v, dict):
                    query = self._sanitize_parameters(query, v)
            return query

    async def execute_query(
        self, cypher_query_, **kwargs: Any
    ) -> tuple[dict[str, Any], None, None]:
        params = dict(kwargs)
        if isinstance(cypher_query_, list):
            for q in cypher_query_:
                result, _, _ = self._run_query(q[0], q[1])
            return result, None, None
        else:
            return self._run_query(cypher_query_, params)

    def _run_query(self, cypher_query_, params):
        cypher_query_ = str(self._sanitize_parameters(cypher_query_, params))
        try:
            result = self.client.query(cypher_query_, params=params)
        except Exception as e:
            logger.error('Query: %s', cypher_query_)
            logger.error('Parameters: %s', params)
            logger.error('Error executing query: %s', e)
            raise e

        return result, None, None

    def session(self, database: str | None = None) -> GraphDriverSession:
        return NeptuneDriverSession(driver=self)

    async def close(self) -> None:
        return self.client.client.close()

    async def _delete_all_data(self) -> Any:
        return await self.execute_query('MATCH (n) DETACH DELETE n')

    def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
        return self.delete_all_indexes_impl()

    async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
        # Drop the OpenSearch Serverless indices alongside the graph data
        return self.delete_aoss_indices()

    async def create_aoss_indices(self):
        for index in aoss_indices:
            index_name = index['index_name']
            client = self.aoss_client
            if not client.indices.exists(index=index_name):
                client.indices.create(index=index_name, body=index['body'])
        # Sleep for 1 minute to let the index creation complete
        await asyncio.sleep(60)

    async def delete_aoss_indices(self):
        for index in aoss_indices:
            index_name = index['index_name']
            client = self.aoss_client
            if client.indices.exists(index=index_name):
                client.indices.delete(index=index_name)

    def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
        for index in aoss_indices:
            if name.lower() == index['index_name']:
                index['query']['query']['multi_match']['query'] = query_text
                query = {'size': limit, 'query': index['query']}
                resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
                return resp
        return {}

    def save_to_aoss(self, name: str, data: list[dict]) -> int:
        for index in aoss_indices:
            if name.lower() == index['index_name']:
                to_index = []
                for d in data:
                    item = {'_index': name}
                    for p in index['body']['mappings']['properties']:
                        item[p] = d[p]
                    to_index.append(item)
                success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
                if failed > 0:
                    return success
                else:
                    return 0

        return 0


class NeptuneDriverSession(GraphDriverSession):
    def __init__(self, driver: NeptuneDriver):  # type: ignore[reportUnknownArgumentType]
        self.driver = driver

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # No cleanup needed for Neptune, but method must exist
        pass

    async def close(self):
        # No explicit close needed for Neptune, but method must exist
        pass

    async def execute_write(self, func, *args, **kwargs):
        # Directly await the provided async function with `self` as the transaction/session
        return await func(self, *args, **kwargs)

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        if isinstance(query, list):
            res = None
            for q in query:
                res = await self.driver.execute_query(q, **kwargs)
            return res
        else:
            return await self.driver.execute_query(str(query), **kwargs)
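`_sanitize_parameters` in the driver above converts `datetime` parameters to ISO-8601 strings before sending them to Neptune, since they cannot be passed through as native values. A standalone sketch of just the scalar branch (`sanitize` is a hypothetical helper; the real method also rewrites the query for lists and recurses into dicts):

```python
import datetime

def sanitize(params: dict) -> dict:
    # ISO-format any datetime parameter, mirroring the scalar branch of
    # NeptuneDriver._sanitize_parameters; list and dict values get extra
    # handling (datetime() wrapping in the query text) in the real driver.
    for key, value in params.items():
        if isinstance(value, datetime.datetime):
            params[key] = value.isoformat()
    return params

clean = sanitize({'since': datetime.datetime(2024, 1, 1, 12, 30), 'limit': 5})
```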
graphiti_core/edges.py

@@ -24,13 +24,14 @@ from uuid import uuid4
 from pydantic import BaseModel, Field
 from typing_extensions import LiteralString
 
-from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.driver.driver import GraphDriver, GraphProvider
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
 from graphiti_core.helpers import parse_db_date
 from graphiti_core.models.edges.edge_db_queries import (
     COMMUNITY_EDGE_RETURN,
     ENTITY_EDGE_RETURN,
+    ENTITY_EDGE_RETURN_NEPTUNE,
     EPISODIC_EDGE_RETURN,
     EPISODIC_EDGE_SAVE,
     get_community_edge_save_query,

@@ -214,11 +215,19 @@ class EntityEdge(Edge):
         return self.fact_embedding
 
     async def load_fact_embedding(self, driver: GraphDriver):
-        records, _, _ = await driver.execute_query(
-            """
+        if driver.provider == GraphProvider.NEPTUNE:
+            query: LiteralString = """
+                MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
+                RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
+            """
+        else:
+            query: LiteralString = """
                 MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
                 RETURN e.fact_embedding AS fact_embedding
-            """,
+            """
+
+        records, _, _ = await driver.execute_query(
+            query,
             uuid=self.uuid,
             routing_='r',
         )

@@ -246,6 +255,9 @@ class EntityEdge(Edge):
 
         edge_data.update(self.attributes or {})
 
+        if driver.provider == GraphProvider.NEPTUNE:
+            driver.save_to_aoss('edge_name_and_fact', [edge_data])  # pyright: ignore reportAttributeAccessIssue
+
         result = await driver.execute_query(
             get_entity_edge_save_query(driver.provider),
             edge_data=edge_data,

@@ -262,7 +274,11 @@ class EntityEdge(Edge):
             MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
             RETURN
             """
-            + ENTITY_EDGE_RETURN,
+            + (
+                ENTITY_EDGE_RETURN_NEPTUNE
+                if driver.provider == GraphProvider.NEPTUNE
+                else ENTITY_EDGE_RETURN
+            ),
             uuid=uuid,
             routing_='r',
         )

@@ -284,7 +300,11 @@ class EntityEdge(Edge):
             WHERE e.uuid IN $uuids
             RETURN
             """
-            + ENTITY_EDGE_RETURN,
+            + (
+                ENTITY_EDGE_RETURN_NEPTUNE
+                if driver.provider == GraphProvider.NEPTUNE
+                else ENTITY_EDGE_RETURN
+            ),
             uuids=uuids,
             routing_='r',
         )

@@ -321,7 +341,11 @@ class EntityEdge(Edge):
             + """
             RETURN
             """
-            + ENTITY_EDGE_RETURN
+            + (
+                ENTITY_EDGE_RETURN_NEPTUNE
+                if driver.provider == GraphProvider.NEPTUNE
+                else ENTITY_EDGE_RETURN
+            )
             + with_embeddings_query
             + """
             ORDER BY e.uuid DESC

@@ -346,7 +370,11 @@ class EntityEdge(Edge):
             MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
             RETURN
             """
-            + ENTITY_EDGE_RETURN,
+            + (
+                ENTITY_EDGE_RETURN_NEPTUNE
+                if driver.provider == GraphProvider.NEPTUNE
+                else ENTITY_EDGE_RETURN
+            ),
             node_uuid=node_uuid,
             routing_='r',
         )
graphiti_core/models/edges/edge_db_queries.py

@@ -43,47 +43,70 @@ EPISODIC_EDGE_RETURN = """
 
 
 def get_entity_edge_save_query(provider: GraphProvider) -> str:
-    if provider == GraphProvider.FALKORDB:
-        return """
-        MATCH (source:Entity {uuid: $edge_data.source_uuid})
-        MATCH (target:Entity {uuid: $edge_data.target_uuid})
-        MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
-        SET e = $edge_data
-        RETURN e.uuid AS uuid
-        """
-
-    return """
-    MATCH (source:Entity {uuid: $edge_data.source_uuid})
-    MATCH (target:Entity {uuid: $edge_data.target_uuid})
-    MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
-    SET e = $edge_data
-    WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
-    RETURN e.uuid AS uuid
-    """
+    match provider:
+        case GraphProvider.FALKORDB:
+            return """
+                MATCH (source:Entity {uuid: $edge_data.source_uuid})
+                MATCH (target:Entity {uuid: $edge_data.target_uuid})
+                MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
+                SET e = $edge_data
+                RETURN e.uuid AS uuid
+            """
+        case GraphProvider.NEPTUNE:
+            return """
+                MATCH (source:Entity {uuid: $edge_data.source_uuid})
+                MATCH (target:Entity {uuid: $edge_data.target_uuid})
+                MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
+                SET e = removeKeyFromMap(removeKeyFromMap($edge_data, "fact_embedding"), "episodes")
+                SET e.fact_embedding = join([x IN coalesce($edge_data.fact_embedding, []) | toString(x) ], ",")
+                SET e.episodes = join($edge_data.episodes, ",")
+                RETURN $edge_data.uuid AS uuid
+            """
+        case _:  # Neo4j
+            return """
+                MATCH (source:Entity {uuid: $edge_data.source_uuid})
+                MATCH (target:Entity {uuid: $edge_data.target_uuid})
+                MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
+                SET e = $edge_data
+                WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
+                RETURN e.uuid AS uuid
+            """
 
 
 def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
-    if provider == GraphProvider.FALKORDB:
-        return """
-        UNWIND $entity_edges AS edge
-        MATCH (source:Entity {uuid: edge.source_node_uuid})
-        MATCH (target:Entity {uuid: edge.target_node_uuid})
-        MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
-        SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
-        created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
-        WITH r, edge
-        RETURN edge.uuid AS uuid
-        """
-
-    return """
-    UNWIND $entity_edges AS edge
-    MATCH (source:Entity {uuid: edge.source_node_uuid})
-    MATCH (target:Entity {uuid: edge.target_node_uuid})
-    MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
-    SET e = edge
-    WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
-    RETURN edge.uuid AS uuid
-    """
+    match provider:
+        case GraphProvider.FALKORDB:
+            return """
+                UNWIND $entity_edges AS edge
+                MATCH (source:Entity {uuid: edge.source_node_uuid})
+                MATCH (target:Entity {uuid: edge.target_node_uuid})
+                MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
+                SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
+                created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
+                WITH r, edge
+                RETURN edge.uuid AS uuid
+            """
+        case GraphProvider.NEPTUNE:
+            return """
+                UNWIND $entity_edges AS edge
+                MATCH (source:Entity {uuid: edge.source_node_uuid})
+                MATCH (target:Entity {uuid: edge.target_node_uuid})
+                MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
+                SET r = removeKeyFromMap(removeKeyFromMap(edge, "fact_embedding"), "episodes")
+                SET r.fact_embedding = join([x IN coalesce(edge.fact_embedding, []) | toString(x) ], ",")
+                SET r.episodes = join(edge.episodes, ",")
+                RETURN edge.uuid AS uuid
+            """
+        case _:
+            return """
+                UNWIND $entity_edges AS edge
+                MATCH (source:Entity {uuid: edge.source_node_uuid})
+                MATCH (target:Entity {uuid: edge.target_node_uuid})
+                MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
+                SET e = edge
+                WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
+                RETURN edge.uuid AS uuid
+            """
 
 
 ENTITY_EDGE_RETURN = """

@@ -101,24 +124,51 @@ ENTITY_EDGE_RETURN = """
     properties(e) AS attributes
 """
 
+ENTITY_EDGE_RETURN_NEPTUNE = """
+    e.uuid AS uuid,
+    n.uuid AS source_node_uuid,
+    m.uuid AS target_node_uuid,
+    e.group_id AS group_id,
+    e.name AS name,
+    e.fact AS fact,
+    split(e.episodes, ',') AS episodes,
+    e.created_at AS created_at,
+    e.expired_at AS expired_at,
+    e.valid_at AS valid_at,
+    e.invalid_at AS invalid_at,
+    properties(e) AS attributes
+"""
+
 
 def get_community_edge_save_query(provider: GraphProvider) -> str:
-    if provider == GraphProvider.FALKORDB:
-        return """
-        MATCH (community:Community {uuid: $community_uuid})
-        MATCH (node {uuid: $entity_uuid})
-        MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
-        SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
-        RETURN e.uuid AS uuid
-        """
-
-    return """
-    MATCH (community:Community {uuid: $community_uuid})
-    MATCH (node:Entity | Community {uuid: $entity_uuid})
-    MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
-    SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
-    RETURN e.uuid AS uuid
-    """
+    match provider:
+        case GraphProvider.FALKORDB:
+            return """
+                MATCH (community:Community {uuid: $community_uuid})
+                MATCH (node {uuid: $entity_uuid})
+                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
+                SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
+                RETURN e.uuid AS uuid
+            """
+        case GraphProvider.NEPTUNE:
+            return """
+                MATCH (community:Community {uuid: $community_uuid})
+                MATCH (node {uuid: $entity_uuid})
+                WHERE node:Entity OR node:Community
+                MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
+                SET r.uuid = $uuid
+                SET r.group_id = $group_id
+                SET r.created_at = $created_at
+                RETURN r.uuid AS uuid
+            """
+        case _:  # Neo4j
+            return """
+                MATCH (community:Community {uuid: $community_uuid})
+                MATCH (node:Entity | Community {uuid: $entity_uuid})
+                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
+                SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
+                RETURN e.uuid AS uuid
+            """
 
 
 COMMUNITY_EDGE_RETURN = """
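The Neptune branches above have no vector property type to write to, so embeddings are stored as comma-joined strings (`join([x IN ... | toString(x)], ",")`) and decoded on read with `split(...) | toFloat(x)`. A Python sketch of that round-trip (hypothetical helper names, for illustration only):

```python
def encode_embedding(vec: list[float]) -> str:
    # Save side: join([x IN $fact_embedding | toString(x)], ",")
    return ','.join(str(x) for x in vec)

def decode_embedding(text: str) -> list[float]:
    # Load side: [x IN split(e.fact_embedding, ",") | toFloat(x)]
    return [float(x) for x in text.split(',')]

vec = [0.1, -0.25, 3.0]
restored = decode_embedding(encode_embedding(vec))
```

The string encoding preserves the values but loses vector-index support, which is why full-text and similarity search are delegated to OpenSearch Serverless in this driver.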
|
|
|||
|
|
@ -18,21 +18,45 @@ from typing import Any
|
|||
|
||||
from graphiti_core.driver.driver import GraphProvider
|
||||
|
||||
EPISODIC_NODE_SAVE = """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
|
||||
EPISODIC_NODE_SAVE_BULK = """
|
||||
UNWIND $episodes AS episode
|
||||
MERGE (n:Episodic {uuid: episode.uuid})
|
||||
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
|
||||
source: episode.source, content: episode.content,
|
||||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
def get_episode_node_save_query(provider: GraphProvider) -> str:
|
||||
match provider:
|
||||
case GraphProvider.NEPTUNE:
|
||||
return """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||
entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j and FalkorDB
|
||||
return """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
|
||||
|
||||
def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
|
||||
match provider:
|
||||
case GraphProvider.NEPTUNE:
|
||||
return """
|
||||
UNWIND $episodes AS episode
|
||||
MERGE (n:Episodic {uuid: episode.uuid})
|
||||
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
|
||||
source: episode.source, content: episode.content,
|
||||
entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
case _: # Neo4j and FalkorDB
|
||||
return """
|
||||
UNWIND $episodes AS episode
|
||||
MERGE (n:Episodic {uuid: episode.uuid})
|
||||
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
|
||||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
|
||||
|
||||
EPISODIC_NODE_RETURN = """
|
||||
e.content AS content,
|
||||
|
|
@@ -46,54 +70,96 @@ EPISODIC_NODE_RETURN = """
     e.entity_edges AS entity_edges
 """
 
+EPISODIC_NODE_RETURN_NEPTUNE = """
+    e.content AS content,
+    e.created_at AS created_at,
+    e.valid_at AS valid_at,
+    e.uuid AS uuid,
+    e.name AS name,
+    e.group_id AS group_id,
+    e.source_description AS source_description,
+    e.source AS source,
+    split(e.entity_edges, ",") AS entity_edges
+"""
+
 
 def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
-    if provider == GraphProvider.FALKORDB:
-        return f"""
-            MERGE (n:Entity {{uuid: $entity_data.uuid}})
-            SET n:{labels}
-            SET n = $entity_data
-            RETURN n.uuid AS uuid
-        """
-
-    return f"""
-        MERGE (n:Entity {{uuid: $entity_data.uuid}})
-        SET n:{labels}
-        SET n = $entity_data
-        WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
-        RETURN n.uuid AS uuid
-    """
+    match provider:
+        case GraphProvider.FALKORDB:
+            return f"""
+                MERGE (n:Entity {{uuid: $entity_data.uuid}})
+                SET n:{labels}
+                SET n = $entity_data
+                RETURN n.uuid AS uuid
+            """
+        case GraphProvider.NEPTUNE:
+            label_subquery = ''
+            for label in labels.split(':'):
+                label_subquery += f' SET n:{label}\n'
+            return f"""
+                MERGE (n:Entity {{uuid: $entity_data.uuid}})
+                {label_subquery}
+                SET n = removeKeyFromMap(removeKeyFromMap($entity_data, "labels"), "name_embedding")
+                SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",")
+                RETURN n.uuid AS uuid
+            """
+        case _:
+            return f"""
+                MERGE (n:Entity {{uuid: $entity_data.uuid}})
+                SET n:{labels}
+                SET n = $entity_data
+                WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
+                RETURN n.uuid AS uuid
+            """
 
 
 def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
-    if provider == GraphProvider.FALKORDB:
-        queries = []
-        for node in nodes:
-            for label in node['labels']:
-                queries.append(
-                    (
-                        f"""
-                        UNWIND $nodes AS node
-                        MERGE (n:Entity {{uuid: node.uuid}})
-                        SET n:{label}
-                        SET n = node
-                        WITH n, node
-                        SET n.name_embedding = vecf32(node.name_embedding)
-                        RETURN n.uuid AS uuid
-                        """,
-                        {'nodes': [node]},
-                    )
-                )
-        return queries
-
-    return """
-        UNWIND $nodes AS node
-        MERGE (n:Entity {uuid: node.uuid})
-        SET n:$(node.labels)
-        SET n = node
-        WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
-        RETURN n.uuid AS uuid
-    """
+    match provider:
+        case GraphProvider.FALKORDB:
+            queries = []
+            for node in nodes:
+                for label in node['labels']:
+                    queries.append(
+                        (
+                            f"""
+                            UNWIND $nodes AS node
+                            MERGE (n:Entity {{uuid: node.uuid}})
+                            SET n:{label}
+                            SET n = node
+                            WITH n, node
+                            SET n.name_embedding = vecf32(node.name_embedding)
+                            RETURN n.uuid AS uuid
+                            """,
+                            {'nodes': [node]},
+                        )
+                    )
+            return queries
+        case GraphProvider.NEPTUNE:
+            queries = []
+            for node in nodes:
+                labels = ''
+                for label in node['labels']:
+                    labels += f' SET n:{label}\n'
+                queries.append(
+                    (
+                        f"""
+                        UNWIND $nodes AS node
+                        MERGE (n:Entity {{uuid: node.uuid}})
+                        {labels}
+                        SET n = removeKeyFromMap(removeKeyFromMap(node, "labels"), "name_embedding")
+                        SET n.name_embedding = join([x IN coalesce(node.name_embedding, []) | toString(x) ], ",")
+                        RETURN n.uuid AS uuid
+                        """,
+                        {'nodes': [node]},
+                    )
+                )
+            return queries
+        case _:  # Neo4j
+            return """
+                UNWIND $nodes AS node
+                MERGE (n:Entity {uuid: node.uuid})
+                SET n:$(node.labels)
+                SET n = node
+                WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
+                RETURN n.uuid AS uuid
+            """
 
 
 ENTITY_NODE_RETURN = """
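Neptune's openCypher does not accept the dynamic `SET n:{labels}` form with a colon-joined label string, so the Neptune cases above expand the labels into one `SET n:<label>` clause per label. A minimal Python sketch of that expansion (the helper name `build_label_clauses` is hypothetical; the diff inlines this loop as `label_subquery`):

```python
def build_label_clauses(labels: str) -> str:
    """Expand a colon-joined label string (e.g. 'Entity:Person')
    into one openCypher SET clause per label."""
    return ''.join(f'SET n:{label}\n' for label in labels.split(':'))


# Produces one SET clause per label, each on its own line.
clauses = build_label_clauses('Entity:Person')
```

The generated clauses are then interpolated into the MERGE query template, as the `{label_subquery}` placeholder does above.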
@@ -108,19 +174,27 @@ ENTITY_NODE_RETURN = """
 
 
 def get_community_node_save_query(provider: GraphProvider) -> str:
-    if provider == GraphProvider.FALKORDB:
-        return """
-            MERGE (n:Community {uuid: $uuid})
-            SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: vecf32($name_embedding)}
-            RETURN n.uuid AS uuid
-        """
-
-    return """
-        MERGE (n:Community {uuid: $uuid})
-        SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
-        WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
-        RETURN n.uuid AS uuid
-    """
+    match provider:
+        case GraphProvider.FALKORDB:
+            return """
+                MERGE (n:Community {uuid: $uuid})
+                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: vecf32($name_embedding)}
+                RETURN n.uuid AS uuid
+            """
+        case GraphProvider.NEPTUNE:
+            return """
+                MERGE (n:Community {uuid: $uuid})
+                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
+                SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
+                RETURN n.uuid AS uuid
+            """
+        case _:  # Neo4j
+            return """
+                MERGE (n:Community {uuid: $uuid})
+                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
+                WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
+                RETURN n.uuid AS uuid
+            """
 
 
 COMMUNITY_NODE_RETURN = """
@@ -131,3 +205,12 @@ COMMUNITY_NODE_RETURN = """
     n.summary AS summary,
     n.created_at AS created_at
 """
+
+COMMUNITY_NODE_RETURN_NEPTUNE = """
+    n.uuid AS uuid,
+    n.name AS name,
+    [x IN split(n.name_embedding, ",") | toFloat(x)] AS name_embedding,
+    n.group_id AS group_id,
+    n.summary AS summary,
+    n.created_at AS created_at
+"""
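Because Neptune does not persist float arrays as node properties the way Neo4j does, these queries store embeddings as comma-joined strings (`join([... | toString(x)], ",")`) and read them back with `split`/`toFloat`. The same round trip can be sketched in plain Python; the helper names are illustrative, not part of the diff:

```python
def serialize_embedding(vec: list[float]) -> str:
    # Mirrors: join([x IN coalesce(..., []) | toString(x)], ",")
    return ','.join(str(x) for x in vec)


def deserialize_embedding(raw: str) -> list[float]:
    # Mirrors: [x IN split(..., ",") | toFloat(x)]
    # Guard the empty string, since ''.split(',') yields [''].
    return [float(x) for x in raw.split(',')] if raw else []


vec = [0.25, -1.5, 3.0]
restored = deserialize_embedding(serialize_embedding(vec))
```

Python's `str()` on a float emits the shortest repr that round-trips exactly, so the stored vector survives the string detour without loss.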
@@ -31,11 +31,13 @@ from graphiti_core.errors import NodeNotFoundError
 from graphiti_core.helpers import parse_db_date
 from graphiti_core.models.nodes.node_db_queries import (
     COMMUNITY_NODE_RETURN,
+    COMMUNITY_NODE_RETURN_NEPTUNE,
     ENTITY_NODE_RETURN,
     EPISODIC_NODE_RETURN,
-    EPISODIC_NODE_SAVE,
+    EPISODIC_NODE_RETURN_NEPTUNE,
     get_community_node_save_query,
     get_entity_node_save_query,
+    get_episode_node_save_query,
 )
 from graphiti_core.utils.datetime_utils import utc_now
@@ -89,23 +91,24 @@ class Node(BaseModel, ABC):
     async def save(self, driver: GraphDriver): ...
 
     async def delete(self, driver: GraphDriver):
-        if driver.provider == GraphProvider.FALKORDB:
-            for label in ['Entity', 'Episodic', 'Community']:
-                await driver.execute_query(
-                    f"""
-                    MATCH (n:{label} {{uuid: $uuid}})
-                    DETACH DELETE n
-                    """,
-                    uuid=self.uuid,
-                )
-        else:
-            await driver.execute_query(
-                """
-                MATCH (n:Entity|Episodic|Community {uuid: $uuid})
-                DETACH DELETE n
-                """,
-                uuid=self.uuid,
-            )
+        match driver.provider:
+            case GraphProvider.NEO4J:
+                await driver.execute_query(
+                    """
+                    MATCH (n:Entity|Episodic|Community {uuid: $uuid})
+                    DETACH DELETE n
+                    """,
+                    uuid=self.uuid,
+                )
+            case _:  # FalkorDB and Neptune
+                for label in ['Entity', 'Episodic', 'Community']:
+                    await driver.execute_query(
+                        f"""
+                        MATCH (n:{label} {{uuid: $uuid}})
+                        DETACH DELETE n
+                        """,
+                        uuid=self.uuid,
+                    )
 
         logger.debug(f'Deleted Node: {self.uuid}')
@@ -119,28 +122,30 @@ class Node(BaseModel, ABC):
 
     @classmethod
     async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
-        if driver.provider == GraphProvider.FALKORDB:
-            for label in ['Entity', 'Episodic', 'Community']:
-                await driver.execute_query(
-                    f"""
-                    MATCH (n:{label} {{group_id: $group_id}})
-                    DETACH DELETE n
-                    """,
-                    group_id=group_id,
-                )
-        else:
-            async with driver.session() as session:
-                await session.run(
-                    """
-                    MATCH (n:Entity|Episodic|Community {group_id: $group_id})
-                    CALL {
-                        WITH n
-                        DETACH DELETE n
-                    } IN TRANSACTIONS OF $batch_size ROWS
-                    """,
-                    group_id=group_id,
-                    batch_size=batch_size,
-                )
+        match driver.provider:
+            case GraphProvider.NEO4J:
+                async with driver.session() as session:
+                    await session.run(
+                        """
+                        MATCH (n:Entity|Episodic|Community {group_id: $group_id})
+                        CALL {
+                            WITH n
+                            DETACH DELETE n
+                        } IN TRANSACTIONS OF $batch_size ROWS
+                        """,
+                        group_id=group_id,
+                        batch_size=batch_size,
+                    )
+            case _:  # FalkorDB and Neptune
+                for label in ['Entity', 'Episodic', 'Community']:
+                    await driver.execute_query(
+                        f"""
+                        MATCH (n:{label} {{group_id: $group_id}})
+                        DETACH DELETE n
+                        """,
+                        group_id=group_id,
+                    )
 
     @classmethod
     async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
@@ -189,8 +194,21 @@ class EpisodicNode(Node):
         )
 
     async def save(self, driver: GraphDriver):
+        if driver.provider == GraphProvider.NEPTUNE:
+            driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
+                'episode_content',
+                [
+                    {
+                        'uuid': self.uuid,
+                        'group_id': self.group_id,
+                        'source': self.source.value,
+                        'content': self.content,
+                        'source_description': self.source_description,
+                    }
+                ],
+            )
         result = await driver.execute_query(
-            EPISODIC_NODE_SAVE,
+            get_episode_node_save_query(driver.provider),
             uuid=self.uuid,
             name=self.name,
             group_id=self.group_id,
@@ -213,7 +231,11 @@ class EpisodicNode(Node):
             MATCH (e:Episodic {uuid: $uuid})
             RETURN
             """
-            + EPISODIC_NODE_RETURN,
+            + (
+                EPISODIC_NODE_RETURN_NEPTUNE
+                if driver.provider == GraphProvider.NEPTUNE
+                else EPISODIC_NODE_RETURN
+            ),
             uuid=uuid,
             routing_='r',
         )
@@ -233,7 +255,11 @@ class EpisodicNode(Node):
             WHERE e.uuid IN $uuids
             RETURN DISTINCT
             """
-            + EPISODIC_NODE_RETURN,
+            + (
+                EPISODIC_NODE_RETURN_NEPTUNE
+                if driver.provider == GraphProvider.NEPTUNE
+                else EPISODIC_NODE_RETURN
+            ),
             uuids=uuids,
             routing_='r',
         )
@@ -262,7 +288,11 @@ class EpisodicNode(Node):
         + """
         RETURN DISTINCT
         """
-        + EPISODIC_NODE_RETURN
+        + (
+            EPISODIC_NODE_RETURN_NEPTUNE
+            if driver.provider == GraphProvider.NEPTUNE
+            else EPISODIC_NODE_RETURN
+        )
         + """
         ORDER BY uuid DESC
         """
@@ -284,7 +314,11 @@ class EpisodicNode(Node):
             MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
             RETURN DISTINCT
             """
-            + EPISODIC_NODE_RETURN,
+            + (
+                EPISODIC_NODE_RETURN_NEPTUNE
+                if driver.provider == GraphProvider.NEPTUNE
+                else EPISODIC_NODE_RETURN
+            ),
             entity_node_uuid=entity_node_uuid,
             routing_='r',
         )
@@ -311,11 +345,18 @@ class EntityNode(Node):
         return self.name_embedding
 
     async def load_name_embedding(self, driver: GraphDriver):
-        records, _, _ = await driver.execute_query(
-            """
-            MATCH (n:Entity {uuid: $uuid})
-            RETURN n.name_embedding AS name_embedding
-            """,
+        if driver.provider == GraphProvider.NEPTUNE:
+            query: LiteralString = """
+                MATCH (n:Entity {uuid: $uuid})
+                RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
+            """
+        else:
+            query: LiteralString = """
+                MATCH (n:Entity {uuid: $uuid})
+                RETURN n.name_embedding AS name_embedding
+            """
+        records, _, _ = await driver.execute_query(
+            query,
             uuid=self.uuid,
             routing_='r',
         )
@@ -336,6 +377,9 @@ class EntityNode(Node):
         }
         entity_data.update(self.attributes or {})
 
+        if driver.provider == GraphProvider.NEPTUNE:
+            driver.save_to_aoss('node_name_and_summary', [entity_data])  # pyright: ignore reportAttributeAccessIssue
+
         labels = ':'.join(self.labels + ['Entity'])
 
         result = await driver.execute_query(
@@ -433,8 +477,13 @@ class CommunityNode(Node):
     summary: str = Field(description='region summary of member nodes', default_factory=str)
 
     async def save(self, driver: GraphDriver):
+        if driver.provider == GraphProvider.NEPTUNE:
+            driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
+                'community_name',
+                [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
+            )
         result = await driver.execute_query(
-            get_community_node_save_query(driver.provider),
+            get_community_node_save_query(driver.provider),  # type: ignore
             uuid=self.uuid,
             name=self.name,
             group_id=self.group_id,
@@ -457,11 +506,19 @@ class CommunityNode(Node):
         return self.name_embedding
 
     async def load_name_embedding(self, driver: GraphDriver):
-        records, _, _ = await driver.execute_query(
-            """
-            MATCH (c:Community {uuid: $uuid})
-            RETURN c.name_embedding AS name_embedding
-            """,
+        if driver.provider == GraphProvider.NEPTUNE:
+            query: LiteralString = """
+                MATCH (c:Community {uuid: $uuid})
+                RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
+            """
+        else:
+            query: LiteralString = """
+                MATCH (c:Community {uuid: $uuid})
+                RETURN c.name_embedding AS name_embedding
+            """
+
+        records, _, _ = await driver.execute_query(
+            query,
             uuid=self.uuid,
             routing_='r',
         )
@@ -478,7 +535,11 @@ class CommunityNode(Node):
             MATCH (n:Community {uuid: $uuid})
             RETURN
             """
-            + COMMUNITY_NODE_RETURN,
+            + (
+                COMMUNITY_NODE_RETURN_NEPTUNE
+                if driver.provider == GraphProvider.NEPTUNE
+                else COMMUNITY_NODE_RETURN
+            ),
             uuid=uuid,
             routing_='r',
         )
@@ -498,7 +559,11 @@ class CommunityNode(Node):
             WHERE n.uuid IN $uuids
             RETURN
             """
-            + COMMUNITY_NODE_RETURN,
+            + (
+                COMMUNITY_NODE_RETURN_NEPTUNE
+                if driver.provider == GraphProvider.NEPTUNE
+                else COMMUNITY_NODE_RETURN
+            ),
             uuids=uuids,
             routing_='r',
         )
@@ -527,7 +592,11 @@ class CommunityNode(Node):
         + """
        RETURN
         """
-        + COMMUNITY_NODE_RETURN
+        + (
+            COMMUNITY_NODE_RETURN_NEPTUNE
+            if driver.provider == GraphProvider.NEPTUNE
+            else COMMUNITY_NODE_RETURN
+        )
        + """
         ORDER BY n.uuid DESC
         """
(File diff suppressed because it is too large.)
@@ -32,8 +32,8 @@ from graphiti_core.models.edges.edge_db_queries import (
     get_entity_edge_save_bulk_query,
 )
 from graphiti_core.models.nodes.node_db_queries import (
-    EPISODIC_NODE_SAVE_BULK,
     get_entity_node_save_bulk_query,
+    get_episode_node_save_bulk_query,
 )
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
 from graphiti_core.utils.maintenance.edge_operations import (
@@ -155,7 +155,7 @@ async def add_nodes_and_edges_bulk_tx(
         edge_data.update(edge.attributes or {})
         edges.append(edge_data)
 
-    await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
+    await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
     entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
     await tx.run(entity_node_save_bulk, nodes=nodes)
     await tx.run(
@@ -21,7 +21,7 @@ from time import time
 from pydantic import BaseModel
 from typing_extensions import LiteralString
 
-from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.driver.driver import GraphDriver, GraphProvider
 from graphiti_core.edges import (
     CommunityEdge,
     EntityEdge,
@@ -504,23 +504,46 @@ async def resolve_extracted_edge(
 async def filter_existing_duplicate_of_edges(
     driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
 ) -> list[tuple[EntityNode, EntityNode]]:
-    query: LiteralString = """
-        UNWIND $duplicate_node_uuids AS duplicate_tuple
-        MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
-        RETURN DISTINCT
-            n.uuid AS source_uuid,
-            m.uuid AS target_uuid
-    """
     if not duplicates_node_tuples:
         return []
 
     duplicate_nodes_map = {
         (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
     }
 
-    records, _, _ = await driver.execute_query(
-        query,
-        duplicate_node_uuids=list(duplicate_nodes_map.keys()),
-        routing_='r',
-    )
+    if driver.provider == GraphProvider.NEPTUNE:
+        query: LiteralString = """
+            UNWIND $duplicate_node_uuids AS duplicate_tuple
+            MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
+            RETURN DISTINCT
+                n.uuid AS source_uuid,
+                m.uuid AS target_uuid
+        """
+
+        duplicate_nodes = [
+            {'source': source.uuid, 'target': target.uuid}
+            for source, target in duplicates_node_tuples
+        ]
+
+        records, _, _ = await driver.execute_query(
+            query,
+            duplicate_node_uuids=duplicate_nodes,
+            routing_='r',
+        )
+    else:
+        query: LiteralString = """
+            UNWIND $duplicate_node_uuids AS duplicate_tuple
+            MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
+            RETURN DISTINCT
+                n.uuid AS source_uuid,
+                m.uuid AS target_uuid
+        """
+
+        records, _, _ = await driver.execute_query(
+            query,
+            duplicate_node_uuids=list(duplicate_nodes_map.keys()),
+            routing_='r',
+        )
 
     # Remove duplicates that already have the IS_DUPLICATE_OF edge
     for record in records:
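The Neptune branch above reshapes the duplicate pairs from tuples into maps with named keys, since its query reads `duplicate_tuple.source`/`duplicate_tuple.target` instead of positional indexing (`duplicate_tuple[0]`). A self-contained sketch of that reshaping, using a stand-in `FakeNode` rather than graphiti's real `EntityNode`:

```python
from dataclasses import dataclass


@dataclass
class FakeNode:
    """Stand-in for an entity node; only the uuid attribute matters here."""
    uuid: str


def to_neptune_params(pairs: list[tuple[FakeNode, FakeNode]]) -> list[dict[str, str]]:
    # Tuple positions become named map keys, matching the
    # duplicate_tuple.source / duplicate_tuple.target accesses in the query.
    return [{'source': s.uuid, 'target': t.uuid} for s, t in pairs]
```

The other providers keep the original `list(duplicate_nodes_map.keys())` parameter, a list of `(source_uuid, target_uuid)` tuples indexed positionally.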
@@ -19,10 +19,13 @@ from datetime import datetime
 
 from typing_extensions import LiteralString
 
-from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.driver.driver import GraphDriver, GraphProvider
 from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
 from graphiti_core.helpers import semaphore_gather
-from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN
+from graphiti_core.models.nodes.node_db_queries import (
+    EPISODIC_NODE_RETURN,
+    EPISODIC_NODE_RETURN_NEPTUNE,
+)
 from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
 
 EPISODE_WINDOW_LEN = 3
@@ -31,6 +34,8 @@ logger = logging.getLogger(__name__)
 
 
 async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
+    if driver.provider == GraphProvider.NEPTUNE:
+        return  # Neptune does not need indexes built
     if delete_existing:
         records, _, _ = await driver.execute_query(
             """
@@ -71,7 +76,7 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
 
     async def delete_group_ids(tx):
         await tx.run(
-            'MATCH (n:Entity|Episodic|Community) WHERE n.group_id IN $group_ids DETACH DELETE n',
+            'MATCH (n) WHERE (n:Entity OR n:Episodic OR n:Community) AND n.group_id IN $group_ids DETACH DELETE n',
             group_ids=group_ids,
         )
 
@@ -117,7 +122,11 @@ async def retrieve_episodes(
         + """
         RETURN
         """
-        + EPISODIC_NODE_RETURN
+        + (
+            EPISODIC_NODE_RETURN_NEPTUNE
+            if driver.provider == GraphProvider.NEPTUNE
+            else EPISODIC_NODE_RETURN
+        )
         + """
         ORDER BY e.valid_at DESC
         LIMIT $num_episodes
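This provider-conditional selection of the RETURN fragment recurs throughout the change set: queries are built by string concatenation, and only the fragment that deserializes Neptune's string-encoded properties differs. A minimal sketch of the pattern; the constants here are shortened placeholders, not the full fragments from node_db_queries.py:

```python
# Placeholder fragments standing in for the real RETURN constants.
EPISODIC_RETURN = 'e.entity_edges AS entity_edges'
EPISODIC_RETURN_NEPTUNE = 'split(e.entity_edges, ",") AS entity_edges'


def episode_return_clause(provider: str) -> str:
    # Neptune persists entity_edges as a delimited string, so its RETURN
    # fragment must split it back into a list; other providers return it as-is.
    return EPISODIC_RETURN_NEPTUNE if provider == 'neptune' else EPISODIC_RETURN


query = 'MATCH (e:Episodic) RETURN ' + episode_return_clause('neptune')
```

Keeping the branching at fragment-selection time means the surrounding MATCH/ORDER BY scaffolding stays identical across providers.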
@@ -18,7 +18,7 @@ dependencies = [
     "tenacity>=9.0.0",
     "numpy>=1.0.0",
     "python-dotenv>=1.0.1",
-    "posthog>=3.0.0",
+    "posthog>=3.0.0"
 ]
 
 [project.urls]
@@ -32,6 +32,7 @@ google-genai = ["google-genai>=1.8.0"]
 falkordb = ["falkordb>=1.1.2,<2.0.0"]
 voyageai = ["voyageai>=0.2.3"]
 sentence-transformers = ["sentence-transformers>=3.2.1"]
+neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"]
 dev = [
     "pyright>=1.1.380",
     "groq>=0.2.0",
@@ -20,12 +20,14 @@ import pytest
 from dotenv import load_dotenv
 
 from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.driver.neptune_driver import NeptuneDriver
 from graphiti_core.helpers import lucene_sanitize
 
 load_dotenv()
 
 HAS_NEO4J = False
 HAS_FALKORDB = False
+HAS_NEPTUNE = False
 if os.getenv('DISABLE_NEO4J') is None:
     try:
         from graphiti_core.driver.neo4j_driver import Neo4jDriver
@@ -42,6 +44,14 @@ if os.getenv('DISABLE_FALKORDB') is None:
     except ImportError:
         pass
 
+if os.getenv('DISABLE_NEPTUNE') is None:
+    try:
+        from graphiti_core.driver.neptune_driver import NeptuneDriver
+
+        HAS_NEPTUNE = True
+    except ImportError:
+        pass
+
 NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
 NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
 NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
@@ -51,6 +61,10 @@ FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
 FALKORDB_USER = os.getenv('FALKORDB_USER', None)
 FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
 
+NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
+NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
+AOSS_HOST = os.getenv('AOSS_HOST', None)
+
 
 def get_driver(driver_name: str) -> GraphDriver:
     if driver_name == 'neo4j':
@@ -66,6 +80,12 @@ def get_driver(driver_name: str) -> GraphDriver:
             username=FALKORDB_USER,
             password=FALKORDB_PASSWORD,
         )
+    elif driver_name == 'neptune':
+        return NeptuneDriver(
+            host=NEPTUNE_HOST,
+            port=int(NEPTUNE_PORT),
+            aoss_host=AOSS_HOST,
+        )
     else:
         raise ValueError(f'Driver {driver_name} not available')
 
@@ -75,6 +95,8 @@ if HAS_NEO4J:
     drivers.append('neo4j')
 if HAS_FALKORDB:
     drivers.append('falkordb')
+if HAS_NEPTUNE:
+    drivers.append('neptune')
 
 
 def test_lucene_sanitize():